Commit 570a8e6

fix unittest
Signed-off-by: Jaedeok Kim <[email protected]>
Parent: 318aba7

1 file changed (+6, −6)

tests/unittest/_torch/test_model_config.py

Lines changed: 6 additions & 6 deletions
@@ -69,20 +69,20 @@ def test_get_bindings_model_config_attention_dp_attn_tp_override(
     bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
 
     # bindings hidden_size is sharded by attn_tp_size and attn_cp_size.
-    assert bindings_cfg.num_heads == cfg.num_attention_heads // (
-        mapping.attn_tp_size * mapping.attn_cp_size
-    )
-    # bindings hidden_size is sharded by tp_size (not attention TP size).
+    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
+    attn_cp_size = mapping.attn_cp_size
+    assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
+    # bindings hidden_size is sharded by attn_tp_size.
     assert bindings_cfg.hidden_size == cfg.hidden_size // mapping.tp_size
     if isinstance(cfg.num_key_value_heads, (list, tuple)):
         expected_num_kv_heads_per_layer = [
-            kv // (mapping.attn_tp_size * mapping.attn_cp_size) for kv in cfg.num_key_value_heads
+            kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
         ]
         assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
         assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
     else:
         assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
-            mapping.attn_tp_size * mapping.attn_cp_size
+            attn_tp_size * attn_cp_size
         )
 
     # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
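
For context on the fix: with enable_attention_dp set, attention appears to be parallelized over requests rather than over heads, so each rank keeps the full head count and the effective attention-TP divisor becomes 1, which is what the updated assertions encode. Below is a minimal sketch of that expectation; FakeMapping is a hypothetical stand-in modeling only the fields the test reads, not the real tensorrt_llm Mapping class:

from dataclasses import dataclass

@dataclass
class FakeMapping:
    # Hypothetical stand-in: only the fields the test touches.
    tp_size: int
    attn_tp_size: int
    attn_cp_size: int
    enable_attention_dp: bool

def expected_num_heads(num_attention_heads: int, mapping: FakeMapping) -> int:
    # Mirrors the fixed assertion: with attention DP enabled, heads are
    # replicated across the attention-TP group, so divide by 1 instead
    # of attn_tp_size.
    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
    return num_attention_heads // (attn_tp_size * mapping.attn_cp_size)

# 32 heads with attn_tp_size=4: sharded to 8 without attention DP,
# kept at 32 when attention DP replicates the heads.
m = FakeMapping(tp_size=4, attn_tp_size=4, attn_cp_size=1, enable_attention_dp=False)
assert expected_num_heads(32, m) == 8
m.enable_attention_dp = True
assert expected_num_heads(32, m) == 32

The same attn_tp_size * attn_cp_size divisor drives the num_kv_heads expectations in the hunk above; hidden_size, by contrast, still divides by mapping.tp_size.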
