File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed
Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -69,20 +69,20 @@ def test_get_bindings_model_config_attention_dp_attn_tp_override(
6969 bindings_cfg = model_config .get_bindings_model_config (tokens_per_block = tokens_per_block )
7070
7171 # bindings hidden_size is sharded by attn_tp_size and attn_cp_size.
72- assert bindings_cfg . num_heads == cfg . num_attention_heads // (
73- mapping . attn_tp_size * mapping .attn_cp_size
74- )
75- # bindings hidden_size is sharded by tp_size (not attention TP size) .
72+ attn_tp_size = mapping . attn_tp_size if not mapping . enable_attention_dp else 1
73+ attn_cp_size = mapping .attn_cp_size
74+ assert bindings_cfg . num_heads == cfg . num_attention_heads // ( attn_tp_size * attn_cp_size )
75+ # bindings hidden_size is sharded by attn_tp_size .
7676 assert bindings_cfg .hidden_size == cfg .hidden_size // mapping .tp_size
7777 if isinstance (cfg .num_key_value_heads , (list , tuple )):
7878 expected_num_kv_heads_per_layer = [
79- kv // (mapping . attn_tp_size * mapping . attn_cp_size ) for kv in cfg .num_key_value_heads
79+ kv // (attn_tp_size * attn_cp_size ) for kv in cfg .num_key_value_heads
8080 ]
8181 assert list (bindings_cfg .num_kv_heads_per_layer ) == expected_num_kv_heads_per_layer
8282 assert bindings_cfg .num_kv_heads (0 ) == expected_num_kv_heads_per_layer [0 ]
8383 else :
8484 assert bindings_cfg .num_kv_heads (0 ) == cfg .num_key_value_heads // (
85- mapping . attn_tp_size * mapping . attn_cp_size
85+ attn_tp_size * attn_cp_size
8686 )
8787
8888 # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
You can’t perform that action at this time.
0 commit comments