)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op

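+# Predefined tensor-parallel plan for the config-driven sharding tests below: maps
+# module/parameter name suffixes to sharding actions (colwise, rowwise, gather,
+# sequence_parallel, ...) in the style of Hugging Face `base_model_tp_plan` entries.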
+base_model_tp_plan = {
+    "q_proj": "colwise",
+    "k_proj": "colwise",
+    "v_proj": "colwise",
+    "o_proj": "rowwise",
+    "gate_proj": "colwise",
+    "up_proj": "colwise",
+    "down_proj": "rowwise",
+    "linear1": "colwise",
+    "linear2": "rowwise",
+    "linear": "gather",
+    "input_layernorm.weight": "sequence_parallel",
+    "post_attention_layernorm.weight": "sequence_parallel",
+    "norm.weight": "sequence_parallel",
+    "shared_expert.gate_proj": "local_colwise",
+    "shared_expert.up_proj": "local_colwise",
+    "shared_expert.down_proj": "local_rowwise",
+    "experts.gate_up_proj": "local_packed_rowwise",
+    "experts.down_proj": "local_colwise",
+    "experts": "local",
+    "feed_forward": "gather",
+    "self": "gather",
+}
+
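+# Minimal predefined model config for the config-driven path: the tests below pass
+# only a head_dim and the TP plan above to `create_sharding_from_config`.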
+predefined_config = {
+    "head_dim": 8,
+    "tp_plan": base_model_tp_plan,
+}
+


class GQA_Block(nn.Module):
    def __init__(
@@ -83,6 +112,7 @@ def _run_job(
    model_cls: nn.Module,
    dist_op_expected: str,
    bias: bool,
+    from_config: bool,
    rank: int,
    world_size: int,
) -> None:
@@ -129,6 +159,7 @@ def _get_expected_num_params(num_p_og: int) -> int:
            num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
        else:
            num_params = num_p_og // world_size + num_update
+        print(f"\n\nnum_p_og: {num_p_og}, num_params: {num_params}")
        return num_params

    def verify_local_weight_sizes(gm) -> bool:
@@ -147,8 +178,12 @@ def verify_local_weight_sizes(gm) -> bool:
    op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)

    def transform_func(gm) -> None:
-        sharding_config = ShardingConfig()
-        detect_column_row_shard(gm, rank, world_size, sharding_config)
+        sharding_config = ShardingConfig(rank=rank, world_size=world_size)
+        if from_config:
+            if world_size > 1:
+                sharding_config.create_sharding_from_config(gm, predefined_config)
+        else:
+            detect_column_row_shard(gm, rank, world_size, sharding_config)
        sharding_transform_executor(gm, sharding_config)

    def combined_graph_check(gm) -> bool:
@@ -174,6 +209,7 @@ def _run_pattern_detection_job(
    bias: bool,
    rank: int,
    world_size: int,
+    from_config: bool,
) -> None:
    # init model and input
    batch_size = 4
@@ -200,7 +236,7 @@ def _run_pattern_detection_job(
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    expected_transformations = []
    # if world_size == 1, no sharding transformations should be detected
-    if world_size > 1:
+    if world_size > 1 or from_config:
        if model_cls == GQA_Block:
            min_local_shape = num_features // num_heads
            for node in gm.graph.nodes:
@@ -262,8 +298,11 @@ def _run_pattern_detection_job(
                    )

    # get detected transformations
-    sharding_config = ShardingConfig()
-    detect_column_row_shard(gm, rank, world_size, sharding_config)
+    sharding_config = ShardingConfig(rank=rank, world_size=world_size)
+    if from_config:
+        sharding_config.create_sharding_from_config(gm, predefined_config)
+    else:
+        detect_column_row_shard(gm, rank, world_size, sharding_config)
    detected_transformations = sharding_config.tp_transforms

    # Run pattern detection test
@@ -272,6 +311,7 @@ def _run_pattern_detection_job(

@pytest.mark.parametrize("device_count", get_device_counts())
@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("from_config", [False, True])
@pytest.mark.parametrize(
    "model_cls, dist_op_expected",
    (
@@ -280,15 +320,22 @@ def _run_pattern_detection_job(
        (GQA_Block, "torch_dist_all_reduce"),
    ),
)
-def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int):
+def test_sharding(
+    model_cls: Type[nn.Module],
+    dist_op_expected: str,
+    bias: bool,
+    device_count: int,
+    from_config: bool,
+):
    dist_common.spawn_multiprocess_job(
-        job=partial(_run_job, model_cls, dist_op_expected, bias),
+        job=partial(_run_job, model_cls, dist_op_expected, bias, from_config),
        size=device_count,
    )


@pytest.mark.parametrize("world_size", [1, 8])
@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("from_config", [False, True])
@pytest.mark.parametrize(
    "model_cls, dist_op_expected",
    (
@@ -298,11 +345,15 @@ def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool,
298345 ),
299346)
300347def test_sharding_pattern_detection (
301- model_cls : Type [nn .Module ], dist_op_expected : str , bias : bool , world_size : int
348+ model_cls : Type [nn .Module ],
349+ dist_op_expected : str ,
350+ bias : bool ,
351+ world_size : int ,
352+ from_config : bool ,
302353):
303354 """Test pattern detection logic without distributed execution.
304355
305356 This test verifies only the pattern detection logic with provided world_size.
306357 No need to run distributed job, can be run on single process.
307358 """
308- _run_pattern_detection_job (model_cls , bias , 0 , world_size )
359+ _run_pattern_detection_job (model_cls , bias , 0 , world_size , from_config )