Skip to content

Commit 7c32860

Browse files
Updated sharding tests
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent f3116e5 commit 7c32860

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def _run_sharding_execution_job(
237237
ssm_state_size=16, # Scaled from 128
238238
mamba_num_heads=num_heads,
239239
mamba_head_dim=num_features // num_heads, # 8
240-
n_groups=1, # Typical value
240+
n_groups=num_heads, # Typical value
241241
chunk_size=256,
242242
conv_kernel=4,
243243
use_conv_bias=bias,
@@ -388,7 +388,7 @@ def _run_pattern_detection_job(
388388
ssm_state_size=16,
389389
mamba_num_heads=num_heads,
390390
mamba_head_dim=num_features // num_heads,
391-
n_groups=1,
391+
n_groups=num_heads,
392392
chunk_size=256,
393393
conv_kernel=4,
394394
use_conv_bias=bias,

0 commit comments

Comments
 (0)