Skip to content

Commit 6149ede

Browse files
fixed SSM sharding test
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent ce661c8 commit 6149ede

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,9 @@ def _run_pattern_detection_job(
572572
fused_weight_dims=None,
573573
)
574574
)
575-
if len(node.args) > 1 and "norm_weight" in node.args[0].name:
575+
if len(node.args) > 1 and (
576+
"norm_weight" in node.args[0].name or "a_log" in node.args[0].name
577+
):
576578
expected_transformations.append(
577579
WeightShardingInfo(
578580
target_node=node.name,

0 commit comments

Comments
 (0)