File tree Expand file tree Collapse file tree 2 files changed +21
-3
lines changed
tensorrt_llm/_torch/auto_deploy Expand file tree Collapse file tree 2 files changed +21
-3
lines changed Original file line number Diff line number Diff line change @@ -292,7 +292,7 @@ def detect_sharding_from_factory_config(
292292 num_simple_shards = 0
293293 num_row_col_shards = 0
294294
295- for lin_node in filtered_nodes (gm .graph .nodes , is_linear_op ):
295+ for lin_node in filtered_nodes (gm .graph .nodes , [ is_linear_op , is_fake_quantized_linear_op ] ):
296296 # use node's weight name to get the module name
297297 module_name = lin_node .args [1 ].target
298298
@@ -368,7 +368,7 @@ def detect_sharding_from_factory_config(
368368 )
369369 num_row_col_shards += 1
370370 else :
371- ad_logger .warning ("Invalid sharding config. Skipping." )
371+ ad_logger .warning (f"Unsupported sharding action { config } . Skipping." )
372372 else :
373373 # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
374374 ad_logger .warning ("Local EP+TP sharding is not supported yet. Skipping." )
@@ -387,7 +387,19 @@ def detect_sharding_from_factory_config(
387387 )
388388 num_simple_shards += 1
389389 else :
390- ad_logger .warning ("Invalid sharding config. Skipping." )
390+ ad_logger .warning (
391+ f"Unsupported sharding action { config } . Fallback to simple shard"
392+ )
393+ sharding_config .tp_transforms .append (
394+ TPShardingInfo .from_node (
395+ lin_node ,
396+ split_dim = SplitDimension .COLUMN ,
397+ rank = rank ,
398+ world_size = world_size ,
399+ dist_op = "all_gather" ,
400+ min_local_shape = 1 ,
401+ )
402+ )
391403 # after successful match, break the loop
392404 break
393405
Original file line number Diff line number Diff line change @@ -239,6 +239,12 @@ def filtered_nodes(
239239 for node in nodes :
240240 if target (node ):
241241 yield node
242+ elif isinstance (target , Iterable ) and all (isinstance (t , Callable ) for t in target ):
243+ for node in nodes :
244+ for t in target :
245+ if t (node ):
246+ yield node
247+ break
242248 else :
243249 # Handle the case where target or ops contains operations
244250 operations = ops if ops is not None else target
You can’t perform that action at this time.
0 commit comments