     LayerType,
     bfs,
     extract_weight_name,
-    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weight_infos,
     get_all_weights_in_subgraph,
-    get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
     is_any_moe_op,
@@ -1296,6 +1295,11 @@ def _shard_parameter_node(
     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name
 
+    if "sharded" in node.meta and node.meta["sharded"]:
+        # Node was already sharded, skip
+        return
+    node.meta["sharded"] = True
+
     num_users = num_users_of_weight_node(node)
     if num_users > 1 or num_users == 0:
         ad_logger.warning(
@@ -1304,12 +1308,17 @@ def _shard_parameter_node(
         return
 
     # Shard weight using the unified function (also updates the parameter)
-    weight_nodes = extract_weight_nodes(node)
-    for weight_node in weight_nodes.weights:
+    all_weight_infos = get_all_weight_infos(node)
+    # Parametrized nodes must have at least one weight (for debugging)
+    assert len(all_weight_infos.weights) > 0, (
+        f"Node {node.name} has no weights - weight mapping may be incorrect"
+    )
+
+    for weight_info in all_weight_infos.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_node.tensor,
-            param_key=weight_node.node_key,
+            weight_tensor=weight_info.tensor,
+            param_key=weight_info.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
@@ -1319,40 +1328,40 @@ def _shard_parameter_node(
         if quantization_cb is not None:
             quantization_cb(
                 gm=gm,
-                submod=weight_node.submod,
+                submod=weight_info.submod,
                 node=node,
-                weight_key=weight_node.node_key,
+                weight_key=weight_info.node_key,
                 weight_new_shape=weight_new_shape,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
             )
 
-    for bias_node in weight_nodes.biases:
+    for bias_info in all_weight_infos.biases:
         if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_node.tensor,
-                param_key=bias_node.node_key,
+                weight_tensor=bias_info.tensor,
+                param_key=bias_info.node_key,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
                 min_local_shape=min_local_shape,
                 fused_weight_dims=fused_weight_dims,
             )
-        elif bias_node is not None and rank != world_size - 1:
+        elif rank != world_size - 1:
             # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
             # double counting it. For all other we will delete the bias.
             args = list(node.args)
             node_bias = args[2]
             args[2] = None
             node.args = tuple(args)
             gm.graph.erase_node(node_bias)
-            bias_param_name = bias_node.node_key.rpartition(".")[-1]
-            setattr(bias_node.submod, bias_param_name, None)
+            bias_param_name = bias_info.node_key.rpartition(".")[-1]
+            setattr(bias_info.submod, bias_param_name, None)
             gm._register_load_state_dict_pre_hook(
-                partial(_load_hook_remove, param_key=bias_node.node_key)
+                partial(_load_hook_remove, param_key=bias_info.node_key)
             )
 
     # # # column shard with no gather: the output is sharded
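
The guard added at the top of _shard_parameter_node makes the pass idempotent: a node that the transformation visits more than once is sharded only once, via a flag stored in the node's meta dict. Below is a minimal standalone sketch of that pattern on a torch.fx graph; the helper name already_sharded and the toy module are illustrative only, not part of this diff.

import torch
import torch.fx as fx


def already_sharded(node: fx.Node) -> bool:
    # Return True if this node was handled before; otherwise mark it and return False.
    if node.meta.get("sharded", False):
        return True
    node.meta["sharded"] = True
    return False


gm = fx.symbolic_trace(torch.nn.Linear(4, 4))
node = next(iter(gm.graph.nodes))      # any node works for the demo
assert already_sharded(node) is False  # first visit: proceed with sharding
assert already_sharded(node) is True   # second visit: skip
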
@@ -2295,47 +2304,37 @@ def detect_sharding_from_config(
         raise ValueError(f"Unsupported sharding source: {source}")
     tp_plan = config["tp_plan"]
 
-    # If the node is inside the attention module, we need to set min_local_shape to the
-    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-    # TODO: is there a better way to check if we are in attention module?
-    attn_names = [
-        "attention",
-        "Attention",
-        "attn",
-        "Attn",
-        "q_proj",
-        "k_proj",
-        "v_proj",
-        "o_proj",
-    ]
-
     num_shards = 0
     num_simple_shards = 0
     num_row_col_shards = 0
     num_attention_shards = 0
     num_ssm_shards = 0
-    head_dim = -1
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
 
+    # use layer_subgraphs to determine the layer_type
+    # and check the validity of the sharding transform
+    layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
+
     for lin_node in linear_nodes:
         # use node's weight name to get the module name
         weight_name = extract_weight_name(lin_node)
-
-        if any(attn_name in weight_name for attn_name in attn_names):
-            # find the next attention node and infer the head_dim
-            next_attention_node, _ = bfs(
-                lin_node, is_any_attention_op, attr_next="users", include_root=False
-            )
-            if next_attention_node is None:
-                # this is the last attention node in the graph. Take the previously found head_dim
-                assert head_dim != -1, "Head dim not found for the last attention node"
-            else:
-                head_dim = shape(next_attention_node)[-1]
-            min_local_shape = head_dim
-            layer_type = LayerType.ATTENTION
+        # get the parent layer_subgraph
+        layer_subgraph = [
+            layer
+            for layer in layer_subgraphs
+            if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
+        ]
+        if len(layer_subgraph) == 1:
+            layer_subgraph = layer_subgraph[0]
+            layer_type = layer_subgraph.layer_type
+        else:
+            if lin_node in unprocessed_linear_nodes:
+                layer_type = LayerType.UNKNOWN
+            else:
+                ad_logger.warning(
+                    f"Failed to find the parent layer_subgraph for linear node {lin_node}. "
+                    f"May result in incorrect sharding."
+                )
-        else:
-            min_local_shape = 1
-            layer_type = LayerType.MLP
 
         # use regex to find if module_name matches any of the keys in sharding_config
         for key in tp_plan.keys():
@@ -2349,11 +2348,6 @@ def detect_sharding_from_config(
             # we have a match. Get the config for this layer
             config = tp_plan[key]
 
-            if config in ["colwise", "mamba"]:
-                cur_node_index = linear_nodes.index(lin_node)
-                layer_subgraph = get_layer_after_linear_node(
-                    linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False
-                )
             if config == "colwise":
                 _process_column_sharding(
                     layer_subgraph=layer_subgraph,
@@ -2366,7 +2360,6 @@ def detect_sharding_from_config(
                         split_dim=SplitDimension.ROW,
                         config=transform_container.config,
                         dist_op="all_reduce",
-                        min_local_shape=min_local_shape,
                         layer_type=layer_type,
                     )
                 ):
@@ -2393,7 +2386,6 @@ def detect_sharding_from_config(
                         split_dim=SplitDimension.COLUMN,
                         config=transform_container.config,
                         dist_op=None,
-                        min_local_shape=min_local_shape,
                         layer_type=layer_type,
                     )
                 )
@@ -2404,7 +2396,6 @@ def detect_sharding_from_config(
                         split_dim=SplitDimension.ROW,
                         config=transform_container.config,
                         dist_op="all_reduce",
-                        min_local_shape=min_local_shape,
                         layer_type=layer_type,
                     )
                 ):
@@ -2423,7 +2414,6 @@ def detect_sharding_from_config(
                         split_dim=SplitDimension.COLUMN,
                         config=transform_container.config,
                         dist_op="all_gather",
-                        min_local_shape=1,
                         layer_type=layer_type,
                     )
                 ):
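
The name-based attention heuristic (attn_names plus a bfs search for the next attention op) is replaced by a structural lookup: each linear node is matched to the layer subgraph that lists it among its opening_nodes or as its terminating_node, and the layer type is read from that subgraph; nodes the subgraph detection skipped fall back to LayerType.UNKNOWN. A sketch of that lookup with a stand-in LayerSubgraph container; the real objects returned by get_all_layer_subgraphs may differ.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Set


class LayerType(Enum):
    ATTENTION = auto()
    MLP = auto()
    UNKNOWN = auto()


@dataclass
class LayerSubgraph:
    opening_nodes: List[Any]
    terminating_node: Any
    layer_type: LayerType


def parent_layer_type(lin_node: Any,
                      layer_subgraphs: List[LayerSubgraph],
                      unprocessed_linear_nodes: Set[Any]) -> LayerType:
    # A linear node belongs to a layer if it is one of the layer's opening
    # projections or its terminating projection.
    matches = [
        layer
        for layer in layer_subgraphs
        if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
    ]
    if len(matches) == 1:
        return matches[0].layer_type
    if lin_node in unprocessed_linear_nodes:
        return LayerType.UNKNOWN
    # Ambiguous or missing owner: the diff only logs a warning here; returning
    # UNKNOWN as well is an assumption of this sketch.
    return LayerType.UNKNOWN


mlp = LayerSubgraph(opening_nodes=["up_proj"], terminating_node="down_proj",
                    layer_type=LayerType.MLP)
assert parent_layer_type("down_proj", [mlp], set()) is LayerType.MLP
assert parent_layer_type("lm_head", [mlp], {"lm_head"}) is LayerType.UNKNOWN
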
@@ -2536,7 +2526,7 @@ def detect_column_row_shard(
         attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
         min_local_shape = 1
 
-        if config.simple_shard_only:
+        if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
             ad_logger.debug(
                 f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
             )
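
The final hunk widens the simple-shard fallback in detect_column_row_shard: a layer whose type could not be classified (LayerType.UNKNOWN from the lookup above) is treated the same as an explicit simple_shard_only request. A trivial sketch of the gate, with a stand-in LayerType enum:

from enum import Enum, auto


class LayerType(Enum):
    ATTENTION = auto()
    MLP = auto()
    UNKNOWN = auto()


def force_simple_shard(simple_shard_only: bool, layer_type: LayerType) -> bool:
    # Simple sharding is used when explicitly requested via config or when the
    # layer could not be classified.
    return simple_shard_only or layer_type == LayerType.UNKNOWN


assert force_simple_shard(False, LayerType.UNKNOWN)
assert not force_simple_shard(False, LayerType.MLP)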