Commit 00c1d17

wip
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 93f9a2b commit 00c1d17

4 files changed: +119 -96 lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def _insert_quantized_linear(
         The state_dict is also updated to contain the sharded weights.
         """
         param_name, _ = extract_param_names_from_node(node)
-        original_weight = gm.get_parameter(param_name)
+        original_weight = gm.get_parameter(param_name[0])
         new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
         modname, _, attrname = param_name.rpartition(".")
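
Note: the `param_name[0]` indexing follows from the new list-based return of `extract_param_names_from_node` (see the node_utils.py diff below), which now yields a list of weight keys and a list of bias keys. A minimal standalone sketch of that call-site pattern, using a plain traced module rather than the repo helpers (the toy module and the key collection are illustrative, not the actual API):

    import torch
    from torch import nn
    from torch.fx import symbolic_trace

    # Toy stand-in for a module whose linear op is being rewritten.
    gm = symbolic_trace(nn.Linear(8, 4))

    # A list-returning helper would hand back weight keys and bias keys separately;
    # here we collect them directly from the graph's get_attr nodes.
    param_keys = [n.target for n in gm.graph.nodes if n.op == "get_attr"]
    weight_keys = [k for k in param_keys if k.endswith("weight")]

    # The call site then indexes the first weight key, mirroring param_name[0].
    original_weight = gm.get_parameter(weight_keys[0])
    new_param = nn.Parameter(original_weight.detach().clone(), requires_grad=False)
    print(weight_keys[0], tuple(new_param.shape))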

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 59 additions & 59 deletions
@@ -38,8 +38,7 @@
     LayerSubgraph,
     LayerType,
     bfs,
-    extract_param_names_from_node,
-    extract_weight_node,
+    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
     get_layer_after_linear_node,
@@ -48,7 +47,6 @@
     is_any_moe_op,
     is_any_ssm_op,
     is_op,
-    num_users_of_weight_node,
     shape,
     subgraph,
 )
@@ -1330,68 +1328,70 @@ def _shard_parameter_node(

     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name
-    num_users = num_users_of_weight_node(node)
-    if num_users > 1 or num_users == 0:
-        ad_logger.warning(
-            f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
-        )
-        return
-    # get weight and bias key
-    weight_key, bias_key = extract_param_names_from_node(node)
-
-    modname = weight_key.rpartition(".")[0]
-    submod = gm.get_submodule(modname)
-
-    # Shard weight using the unified function (also updates the parameter)
-    original_weight = gm.get_parameter(weight_key)
-
-    _, weight_new_shape = shard_weight_tensor(
-        gm=gm,
-        weight_tensor=original_weight,
-        param_key=weight_key,
-        dim=dim,
-        rank=rank,
-        world_size=world_size,
-        min_local_shape=min_local_shape,
-        fused_weight_dims=fused_weight_dims,
-    )
-
-    if bias_key is not None and dim == 0:
-        # update bias for dim 0 --> we can handle it like the weight
-        original_bias = gm.get_parameter(bias_key)
-        shard_weight_tensor(
+    # num_users = num_users_of_weight_node(node)
+    # if num_users > 1 or num_users == 0:
+    #     ad_logger.warning(
+    #         f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
+    #     )
+    #     return
+    # # get weight and bias key
+    # weight_key, bias_key = extract_param_names_from_node(node)
+
+    # modname = weight_key.rpartition(".")[0]
+    # submod = gm.get_submodule(modname)
+
+    # # Shard weight using the unified function (also updates the parameter)
+    # original_weight = gm.get_parameter(weight_key)
+    weight_nodes = extract_weight_nodes(node)
+    for weight_node, bias_node in weight_nodes:
+        _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=original_bias,
-            param_key=bias_key,
+            weight_tensor=weight_node.node,
+            param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
            world_size=world_size,
             min_local_shape=min_local_shape,
             fused_weight_dims=fused_weight_dims,
         )
-    elif bias_key is not None and rank != world_size - 1:
-        # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
-        # double counting it. For all other we will delete the bias.
-        args = list(node.args)
-        node_bias = args[2]
-        args[2] = None
-        node.args = tuple(args)
-        gm.graph.erase_node(node_bias)
-        bias_param_name = bias_key.rpartition(".")[-1]
-        setattr(submod, bias_param_name, None)
-        gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key))
-
-    if quantization_cb is not None:
-        quantization_cb(
-            gm=gm,
-            submod=submod,
-            node=node,
-            weight_key=weight_key,
-            weight_new_shape=weight_new_shape,
-            dim=dim,
-            rank=rank,
-            world_size=world_size,
-        )
+
+        if bias_node is not None and dim == 0:
+            # update bias for dim 0 --> we can handle it like the weight
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=bias_node.node,
+                param_key=bias_node.node_key,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+                min_local_shape=min_local_shape,
+                fused_weight_dims=fused_weight_dims,
+            )
+        elif bias_node is not None and rank != world_size - 1:
+            # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
+            # double counting it. For all other we will delete the bias.
+            args = list(node.args)
+            node_bias = args[2]
+            args[2] = None
+            node.args = tuple(args)
+            gm.graph.erase_node(node_bias)
+            bias_param_name = bias_node.node_key.rpartition(".")[-1]
+            setattr(bias_node.submod, bias_param_name, None)
+            gm._register_load_state_dict_pre_hook(
+                partial(_load_hook_remove, param_key=bias_node.node_key)
+            )
+
+        if quantization_cb is not None:
+            quantization_cb(
+                gm=gm,
+                submod=weight_node.submod,
+                node=node,
+                weight_key=weight_node.node_key,
+                weight_new_shape=weight_new_shape,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+            )

     # # # column shard with no gather: the output is sharded
     if not add_dist:
@@ -2251,7 +2251,7 @@ def detect_sharding_from_config(

     for lin_node in linear_nodes:
         # use node's weight name to get the module name
-        module_name = extract_weight_node(lin_node).target
+        module_name = extract_weight_nodes(lin_node)[0].target

         if any(attn_name in module_name for attn_name in attn_names):
             # find the next attention node and infer the head_dim
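
Note: the rewritten loop shards each reported weight (and its optional bias) in turn instead of assuming a single weight key per node. As a rough, self-contained illustration of the per-rank split that `shard_weight_tensor` is expected to perform (simplified to plain tensors, with no graph rewriting or distributed setup; `shard` below is a local stand-in, not the repo function):

    import torch

    def shard(t: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
        """Return this rank's slice of t along dim (assumes an even split)."""
        return torch.chunk(t, world_size, dim=dim)[rank]

    weight = torch.randn(16, 8)
    bias = torch.randn(16)
    rank, world_size = 1, 4

    # dim 0 (column-parallel): the bias is sharded exactly like the weight
    w0, b0 = shard(weight, 0, rank, world_size), shard(bias, 0, rank, world_size)

    # dim 1 (row-parallel): only the last rank keeps the bias so it is not
    # double counted when the partial outputs are reduced
    w1 = shard(weight, 1, rank, world_size)
    b1 = bias if rank == world_size - 1 else None
    print(w0.shape, b0.shape, w1.shape, b1)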

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 57 additions & 34 deletions
@@ -8,6 +8,7 @@

 import torch
 from pydantic import BaseModel, ConfigDict
+from torch import nn
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.fx import GraphModule, Node

@@ -51,6 +52,13 @@ class LayerSubgraph(BaseModel):
     min_local_shape: int = 1


+class WeightNode(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    node: Node
+    node_key: str
+    submod: nn.Module
+
+
 @dataclass
 class modelopt_quant_params:
     input_node: torch.fx.node.Node = None
@@ -129,10 +137,12 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params


-def extract_weight_node(node: Node) -> int:
-    """Extracts the weight node from the given parametrized node"""
+def extract_weight_nodes(node: Node) -> Tuple[List[WeightNode], List[WeightNode]]:
+    """Extracts the list of weight node and optional bias node from the given parametrized node"""
     gm = node.graph.owning_module
-    param_names = {name for name, _ in gm.named_parameters()}
+    param_names = {name for name, _ in gm.named_parameters()}.union(
+        {name for name, _ in gm.named_buffers()}
+    )

     def find_get_attr_node(weight_node: Node) -> Node:
         """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -157,55 +167,68 @@ def find_get_attr_node(weight_node: Node) -> Node:
         return None

     if is_op(node, torch.ops.aten.bmm):
-        weight_node = node.args[1]
+        # no bias for bmm
+        return [WeightNode(node=node.args[1], node_key=node.args[1].target)], []
     # for other parametrized nodes, we need to find the weight node
     else:
-        weight_nodes = [
+        all_weight_nodes = [
             n for n in node.args if isinstance(n, Node) and find_get_attr_node(n) is not None
         ]
-        # can be two weights (if bias weight is present)
-        weight_node = None
-        if weight_nodes:
-            weight_node = weight_nodes[0]
-        # for modelopt quantized graph, there will be a quantize_op
-        _, weight_params, _ = get_quantization_params_from_linear_node(node)
-        weight_node = weight_params.input_node if weight_params else weight_node
-    assert weight_node is not None, "Expected at least one weight node in the parametrized node"
-    return find_get_attr_node(weight_node)
+        # separate weight nodes and bias nodes
+        weight_nodes = [n for n in all_weight_nodes if n.target.endswith("weight")]
+        bias_nodes = [n for n in all_weight_nodes if n.target.endswith("bias")]
+        weight_nodes = [
+            WeightNode(
+                node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0])
+            )
+            for n in weight_nodes
+        ]
+        bias_nodes = [
+            WeightNode(
+                node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0])
+            )
+            for n in bias_nodes
+        ]
+        return weight_nodes, bias_nodes


 def num_users_of_weight_node(node: Node) -> int:
     """Returns the number of users of the weight node of the given parametrized node."""
-    weight_node = extract_weight_node(node)
+    weight_node = extract_weight_nodes(node)[0]
     return len(weight_node.users) if weight_node is not None else 0


-def extract_param_names_from_node(node: Node) -> Tuple[str, Optional[str]]:
+def extract_param_names_from_node(node: Node) -> Tuple[List[str], Optional[List[str]]]:
     """Extracts the name of the parameter associated with the given parametrized node.

     Args:
         node: node with weight parameters in the graph.
     """
-    weight_node = extract_weight_node(node)
+    # try:

-    assert weight_node, "Cannot identify weight parameter of linear node."
+    # except:
+    #     a = 1

-    # Map arg to named parameter
-    weight_name = weight_node.target
+    # assert weight_node, "Cannot identify weight parameter of linear node."

-    # check for bias
-    if is_op(node, torch.ops.aten.bmm):
-        bias_node = node.args[2] if len(node.args) > 2 else None
-    else:
-        weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"]
-        if len(weight_nodes) > 1:
-            bias_node = weight_nodes[1]
-        else:
-            bias_node = None
-    assert bias_node is None or bias_node.op == "get_attr"
-    bias_name = bias_node.target if bias_node is not None else None
+    # # Map arg to named parameter
+    # weight_name = weight_node.target
+
+    # # check for bias
+    # if is_op(node, torch.ops.aten.bmm):
+    #     bias_node = node.args[2] if len(node.args) > 2 else None
+    # else:
+    #     weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"]
+    #     if len(weight_nodes) > 1:
+    #         bias_node = weight_nodes[1]
+    #     else:
+    #         bias_node = None
+    # assert bias_node is None or bias_node.op == "get_attr"
+    # bias_name = bias_node.target if bias_node is not None else None

-    return weight_name, bias_name
+    # return weight_name, bias_name
+    weight_nodes, bias_nodes = extract_weight_nodes(node)
+    return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes]


 def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
@@ -751,9 +774,9 @@ def get_weight_shape(
     if not is_any_lin_op(node):
         return None
     if dim is None:
-        return shape(extract_weight_node(node))
+        return shape(extract_weight_nodes(node)[0])
     else:
-        return shape(extract_weight_node(node))[dim]
+        return shape(extract_weight_nodes(node)[0])[dim]


 def get_layer_after_linear_node(
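
Note: `WeightNode` bundles a parameter's `get_attr` node, its fully qualified key, and the owning submodule, and `extract_weight_nodes` groups a node's `get_attr` inputs into weight and bias lists by target suffix. A self-contained sketch of that grouping on a traced toy module, with a local `WeightNode` mirror rather than an import from this file (the `Tiny` module and helper below are made up for illustration):

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.fx import Node, symbolic_trace
    from pydantic import BaseModel, ConfigDict

    class WeightNode(BaseModel):
        # local mirror of the model added in node_utils.py
        model_config = ConfigDict(arbitrary_types_allowed=True)
        node: Node
        node_key: str
        submod: nn.Module

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(4, 8))
            self.bias = nn.Parameter(torch.randn(4))

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)

    gm = symbolic_trace(Tiny())

    def as_weight_node(n: Node) -> WeightNode:
        # node_key is the parameter's qualified name; submod is its owning module
        return WeightNode(
            node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0])
        )

    attrs = [n for n in gm.graph.nodes if n.op == "get_attr"]
    weight_nodes = [as_weight_node(n) for n in attrs if n.target.endswith("weight")]
    bias_nodes = [as_weight_node(n) for n in attrs if n.target.endswith("bias")]
    print([w.node_key for w in weight_nodes], [b.node_key for b in bias_nodes])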

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 2 additions & 2 deletions
@@ -117,8 +117,8 @@ def should_skip_quantization(
     else:
         if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)):
             return True
-        param_name, _ = extract_param_names_from_node(node_or_name)
-        modname, _, _ = param_name.rpartition(".")
+        param_names, _ = extract_param_names_from_node(node_or_name)
+        modname, _, _ = param_names[0].rpartition(".")

     return any(fnmatch(modname, pattern) for pattern in excluded_patterns)
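
Note: with `extract_param_names_from_node` returning lists, the exclusion check derives the module name from the first weight key. A quick self-contained check of the fnmatch logic (the key and patterns below are made-up examples, not values from the repo):

    from fnmatch import fnmatch

    excluded_patterns = ["*lm_head*", "model.layers.0.*"]
    param_names = ["model.layers.0.self_attn.q_proj.weight"]  # first weight key of the node

    modname, _, _ = param_names[0].rpartition(".")
    skip = any(fnmatch(modname, pattern) for pattern in excluded_patterns)
    print(modname, skip)  # model.layers.0.self_attn.q_proj True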
