     LayerSubgraph,
     LayerType,
     bfs,
+    extract_weight_name,
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
@@ -1272,10 +1273,6 @@ def split_fused_tensor(
         fused_dims: list = fused_weight_dims,
         d: int = dim,
     ) -> torch.Tensor:
-        # dim_d = t.shape[d]
-        # num_parts = 1
-        # part_size = dim_d // num_parts
-        # fused_dims = [part_size] * num_parts
         return torch.cat(
             [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)],
             dim=d,
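
Aside: a minimal sketch of what the nested split_fused_tensor closure above does; the shapes and the stand-in split_tensor helper are illustrative, not taken from the PR. A fused QKV-style weight is first split into its logical parts along the shard dim, each part is sharded independently, and the local slices are concatenated back so every rank keeps its share of every fused part.

    import torch

    H, world_size, rank = 8, 2, 0
    fused_weight_dims = [H, H, H]  # row counts of the q/k/v parts inside the fused tensor
    w_fused = torch.randn(3 * H, H)

    def split_tensor(w: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # keep only this rank's slice of one fused part
        return torch.tensor_split(w, world_size, dim=dim)[rank]

    w_local = torch.cat(
        [split_tensor(w) for w in torch.split(w_fused, fused_weight_dims, dim=0)],
        dim=0,
    )
    assert w_local.shape == (3 * H // world_size, H)
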
@@ -1343,23 +1340,35 @@ def _shard_parameter_node(
     # # Shard weight using the unified function (also updates the parameter)
     # original_weight = gm.get_parameter(weight_key)
     weight_nodes = extract_weight_nodes(node)
-    for weight_node, bias_node in weight_nodes:
+    for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_node.node,
+            weight_tensor=weight_node.tensor,
             param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
             min_local_shape=min_local_shape,
             fused_weight_dims=fused_weight_dims,
         )
+        if quantization_cb is not None:
+            quantization_cb(
+                gm=gm,
+                submod=weight_node.submod,
+                node=node,
+                weight_key=weight_node.node_key,
+                weight_new_shape=weight_new_shape,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+            )

-        if bias_node is not None and dim == 0:
+    for bias_node in weight_nodes.biases:
+        if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_node.node,
+                weight_tensor=bias_node.tensor,
                 param_key=bias_node.node_key,
                 dim=dim,
                 rank=rank,
@@ -1381,18 +1390,6 @@ def _shard_parameter_node(
                 partial(_load_hook_remove, param_key=bias_node.node_key)
             )

-        if quantization_cb is not None:
-            quantization_cb(
-                gm=gm,
-                submod=weight_node.submod,
-                node=node,
-                weight_key=weight_node.node_key,
-                weight_new_shape=weight_new_shape,
-                dim=dim,
-                rank=rank,
-                world_size=world_size,
-            )
-
     # # # column shard with no gather: the output is sharded
     if not add_dist:
         return
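
Aside: a hedged toy example (shapes and values are illustrative, not from the PR) of why the bias is sharded only for dim == 0. Column sharding (dim 0) partitions output features, so each rank also keeps the matching slice of the bias; row sharding (dim 1) partitions input features, the partial outputs are summed by an all-reduce, and the full bias must be added exactly once.

    import torch

    out_f, in_f, world_size = 6, 4, 2
    W, b, x = torch.randn(out_f, in_f), torch.randn(out_f), torch.randn(1, in_f)

    # column (dim 0) shard: outputs are split, so the bias is split the same way
    y_col = torch.cat(
        [
            x @ torch.tensor_split(W, world_size, dim=0)[r].T
            + torch.tensor_split(b, world_size, dim=0)[r]
            for r in range(world_size)
        ],
        dim=-1,
    )
    assert torch.allclose(y_col, x @ W.T + b, atol=1e-6)

    # row (dim 1) shard: partial products are summed (the all-reduce); bias is added once, unsharded
    y_row = sum(
        xp @ Wp.T
        for xp, Wp in zip(
            torch.tensor_split(x, world_size, dim=-1), torch.tensor_split(W, world_size, dim=1)
        )
    ) + b
    assert torch.allclose(y_row, x @ W.T + b, atol=1e-6)
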
@@ -1423,107 +1420,6 @@ def _update_node_args(node: Node, args: tuple) -> None:
     )


-def _insert_sharded_moe_stacked(
-    gm: GraphModule,
-    node: Node,
-    rank: int,
-    world_size: int,
-    allreduce_strategy: AllReduceStrategy,
-    scale_names: Sequence[str] = (),
-):
-    """Update the torch_moe node with sliced stacked weight tensors,
-    sharded `selected_experts` and `final_scales(router_logics)`.
-    Add an all_reduce node after the moe node.
-
-    For torch_moe with stacked tensor format (single-element lists containing 3D tensors).
-
-    NOTE: allreduce_strategy is MANDATORY and must be explicitly provided.
-    """
-    if allreduce_strategy is None:
-        raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")
-
-    # Extract the stacked tensors from single-element lists
-    # args[3] = w1_weight (Node representing list with one 3D tensor, or direct list)
-    # args[4] = w2_weight (Node representing list with one 3D tensor, or direct list)
-
-    # Helper to extract tensor node from list (handles both Node and direct list)
-    def extract_tensor_from_list_arg(list_arg):
-        if isinstance(list_arg, Node) and list_arg.target is list:
-            # It's a list() call node - extract from its args
-            return list_arg.args[0][0]  # args[0] is the list content, [0] is first element
-        elif isinstance(list_arg, (list, tuple)):
-            # Direct list
-            return list_arg[0]
-        else:
-            raise ValueError(f"Unexpected list format: {type(list_arg)}")
-
-    w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3])
-    w2_tensor_node = extract_tensor_from_list_arg(node.args[4])
-    num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node)
-
-    args = list(node.args)
-
-    # -- Handle selected_experts and final_scales sharding --
-    selected_experts = args[1]
-    final_scales = args[2]
-
-    experts_per_rank = num_experts // world_size
-
-    with gm.graph.inserting_before(node):
-        lower = experts_per_rank * rank
-        # selected_experts_local = selected_experts - low
-        selected_experts_local = gm.graph.create_node(
-            "call_function", operator.sub, args=(selected_experts, lower), kwargs={}
-        )
-
-        # For num_experts % world_size != 0 case,
-        # assign the last (num_experts % world_size) experts to the last rank
-        div_node = gm.graph.create_node(
-            "call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
-        )
-
-        comp_op = torch.ge if rank == world_size - 1 else torch.eq
-        rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})
-
-        # final_scales_local = final_scales * rank_mask
-        final_scales_local = gm.graph.create_node(
-            "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
-        )
-
-    # -- Transform expert weight parameters --
-    local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank)
-
-    # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H)
-    if isinstance(w3_w1_tensor_node, Node):
-        _transform_bmm_moe_weight_param(
-            gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True
-        )
-
-    # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I)
-    if isinstance(w2_tensor_node, Node):
-        _transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False)
-
-    # -- Update args (keep same lists/nodes, just with transformed parameters) --
-    args[1] = selected_experts_local
-    args[2] = final_scales_local
-    # args[3] and args[4] stay the same - we modified the parameters in-place
-
-    ad_logger.debug(
-        f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
-    )
-
-    node.args = tuple(args)
-
-    # -- add an all_reduce node --
-    with gm.graph.inserting_after(node):
-        dist_node = gm.graph.call_function(
-            torch.ops.auto_deploy.torch_dist_all_reduce.default,
-            args=(node, allreduce_strategy),
-        )
-        node.replace_all_uses_with(dist_node)
-        dist_node.replace_input_with(dist_node, node)
-
-
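
Aside: the function removed above built this expert-parallel routing logic as FX graph nodes; here is a hedged eager-mode toy version (expert counts, ids, and scales are made up) showing the intent: global expert ids are shifted to rank-local ids, and routing scales are masked so each rank only contributes its own experts, with the last rank also absorbing the remainder when num_experts is not divisible by world_size.

    import torch

    num_experts, world_size = 8, 3
    experts_per_rank = num_experts // world_size  # 2; the last rank also owns the 2 leftover experts
    selected_experts = torch.tensor([[0, 5], [7, 2]])      # top-k expert ids per token
    final_scales = torch.tensor([[0.6, 0.4], [0.9, 0.1]])  # routing weights per token

    for rank in range(world_size):
        lower = experts_per_rank * rank
        selected_experts_local = selected_experts - lower  # global id -> local id on this rank
        comp_op = torch.ge if rank == world_size - 1 else torch.eq
        rank_mask = comp_op(selected_experts // experts_per_rank, rank)
        final_scales_local = final_scales * rank_mask      # zero out experts owned by other ranks
        # out-of-range local ids are harmless: their scale is zero, so they contribute nothing
        print(rank, selected_experts_local.tolist(), final_scales_local.tolist())
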
 def _insert_sharded_moe(
     gm: GraphModule,
     node: Node,
@@ -2251,9 +2147,9 @@ def detect_sharding_from_config(

     for lin_node in linear_nodes:
         # use node's weight name to get the module name
-        module_name = extract_weight_nodes(lin_node)[0].target
+        weight_name = extract_weight_name(lin_node)

-        if any(attn_name in module_name for attn_name in attn_names):
+        if any(attn_name in weight_name for attn_name in attn_names):
             # find the next attention node and infer the head_dim
             next_attention_node, _ = bfs(
                 lin_node, is_any_attention_op, attr_next="users", include_root=False
@@ -2277,7 +2173,7 @@ def detect_sharding_from_config(
             # Then we escape dots, and finally we replace @ with .*
             pattern_string = pattern_string.replace("*", "@")
             pattern_regex = re.escape(pattern_string).replace("@", ".*")
-            if re.match(pattern_regex, module_name):
+            if re.match(pattern_regex, weight_name):
                 # we have a match. Get the config for this layer
                 config = tp_plan[key]

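
Aside: a small, hedged example of the wildcard handling above; the tp_plan key and weight name are hypothetical. The `*` is first protected as `@`, the literal dots are regex-escaped, and `@` is then expanded to `.*` before matching against the weight name.

    import re

    pattern_string = "model.layers.*.self_attn.q_proj"       # hypothetical tp_plan key
    weight_name = "model.layers.11.self_attn.q_proj.weight"  # hypothetical parameter name

    pattern_string = pattern_string.replace("*", "@")
    pattern_regex = re.escape(pattern_string).replace("@", ".*")
    # pattern_regex == r"model\.layers\..*\.self_attn\.q_proj"
    assert re.match(pattern_regex, weight_name) is not None
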
@@ -2316,7 +2212,7 @@ def detect_sharding_from_config(
                 elif "local" in config:
                     # Check if this applies to shared experts in EP parallelism.
                     # If yes, apply the TP col-row shard.
-                    if "shared" in module_name:
+                    if "shared" in weight_name:
                         col_row_action = config.replace("local_", "")
                         if col_row_action == "colwise":
                             transform_container.add(