diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index aa7a8fb8d5..6d963f5c7b 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -12,7 +12,7 @@
 from transformer_engine.jax.sharding import MeshResource
 
-from utils import assert_allclose, is_devices_enough
+from utils import assert_allclose, is_devices_enough, is_devices_equal
 
 
 def generate_configs():
@@ -49,7 +49,11 @@ def generate_context_parallel_configs_for_attn():
     TP_sizes = (1, 2)
     for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
         ndev = cp * tp * dp
-        if is_devices_enough(ndev):
+        # Run only those (dp, cp, tp) combinations whose required device count (ndev) exactly matches the number of available GPUs.
+        # For example, if num_GPUs is 8 and ndev=8, all (dp, cp, tp) combinations satisfying ndev = cp * tp * dp are picked.
+        # However, if num_GPUs is 8 and ndev=4, all (dp, cp, tp) combinations satisfying ndev = cp * tp * dp are skipped.
+        # To explicitly pick the ndev=4 combinations, set CUDA_VISIBLE_DEVICES=0,1,2,3, which forces num_GPUs to 4 instead of 8.
+        if is_devices_equal(ndev):
             # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
             if cp != 1:
                 configsL1.append(
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index d0018543d1..d5ebe9f261 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -334,7 +334,7 @@ def test_cross_attn(
 
 
 class TestDistributedContextParallelSelfAttn:
-
+    # TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests
     def impl_test_context_parallel_attn(
         self,
         device_count,
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index ac1b7c3505..a0aee50430 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -1068,41 +1068,70 @@ def check_dqkv(primitive, reference, pad, idx):
     ],
 )
 @pytest.mark.parametrize(
-    "qkv_layout",
-    [
-        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
-        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
-        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
-        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
-        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
-        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
-    ],
-)
-@pytest.mark.parametrize(
-    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
     [
+        # large data size + bf16 + qkv packed
         pytest.param(
-            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
+            2,
+            2048,
+            2048,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            QKVLayout.BS3HD,
+            id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED",
         ),
         pytest.param(
             2,
-            512,
-            1024,
+            2048,
+            2048,
             12,
             12,
             64,
             64,
             jnp.bfloat16,
-            id="2-512-1024-12-12-64-64-BF16-CROSS",
+            QKVLayout.T3HD,
+            id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED",
         ),
+        # mid data size + bf16 + cross attn + kv packed
         pytest.param(
-            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
+            2,
+            512,
+            1024,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            QKVLayout.BSHD_BS2HD,
+            id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED",
         ),
         pytest.param(
-            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
+            2,
+            512,
+            1024,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            QKVLayout.THD_T2HD,
+            id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
         ),
+        # large data size + bf16 + cross attn + diff hidden v dim + qkv separate
         pytest.param(
-            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            32,
+            jnp.bfloat16,
+            QKVLayout.BSHD_BSHD_BSHD,
+            id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE",
         ),
         pytest.param(
             2,
@@ -1113,10 +1142,108 @@ def check_dqkv(primitive, reference, pad, idx):
             64,
             32,
             jnp.bfloat16,
-            id="2-2048-1024-12-12-64-32-BF16-CROSS",
+            QKVLayout.THD_THD_THD,
+            id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE",
+        ),
+        # large data size + bf16 + gqa + kv packed
+        pytest.param(
+            2,
+            2048,
+            2048,
+            12,
+            6,
+            64,
+            64,
+            jnp.bfloat16,
+            QKVLayout.BSHD_BS2HD,
+            id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED",
+        ),
+        pytest.param(
+            2,
+            2048,
+            2048,
+            12,
+            6,
+            64,
+            64,
+            jnp.bfloat16,
+            QKVLayout.THD_T2HD,
+            id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED",
         ),
+        # small data size + fp16 + diff hidden v dim + qkv packed
         pytest.param(
-            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
+            4,
+            128,
+            128,
+            16,
+            16,
+            64,
+            32,
+            jnp.float16,
+            QKVLayout.BS3HD,
+            id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED",
+        ),
+        pytest.param(
+            4,
+            128,
+            128,
+            16,
+            16,
+            64,
+            32,
+            jnp.float16,
+            QKVLayout.T3HD,
+            id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED",
+        ),
+        # small data size + fp16 + kv packed
+        pytest.param(
+            4,
+            128,
+            128,
+            16,
+            16,
+            64,
+            64,
+            jnp.float16,
+            QKVLayout.BSHD_BS2HD,
+            id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED",
+        ),
+        pytest.param(
+            4,
+            128,
+            128,
+            16,
+            16,
+            64,
+            64,
+            jnp.float16,
+            QKVLayout.THD_T2HD,
+            id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED",
+        ),
+        # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
+        pytest.param(
+            2,
+            1024,
+            2048,
+            12,
+            6,
+            128,
+            64,
+            jnp.float16,
+            QKVLayout.BSHD_BSHD_BSHD,
+            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE",
+        ),
+        pytest.param(
+            2,
+            1024,
+            2048,
+            12,
+            6,
+            128,
+            64,
+            jnp.float16,
+            QKVLayout.THD_THD_THD,
+            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
         ),
     ],
 )
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 8055792308..c22b0a6063 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -47,6 +47,13 @@ def is_devices_enough(required):
     return len(jax.devices()) >= required
 
 
+def is_devices_equal(required):
+    """
+    Check if the number of available devices is exactly equal to `required`
+    """
+    return len(jax.devices()) == required
+
+
 def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
     # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
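As added context for the distributed_test_base.py change, here is a minimal standalone sketch (not part of the test suite) of how the exact-device-count gating behaves. The dp/cp size tuples and the select_configs helper are illustrative assumptions; only TP_sizes = (1, 2) is visible in the hunk above.

from itertools import product


def select_configs(num_visible_devices, dp_sizes=(1, 2), cp_sizes=(2, 4), tp_sizes=(1, 2)):
    """Keep only (dp, cp, tp) combinations whose device requirement matches exactly."""
    picked = []
    for dp, cp, tp in product(dp_sizes, cp_sizes, tp_sizes):
        ndev = cp * tp * dp
        # is_devices_enough(ndev) used ">=", so an 8-GPU run would also collect 2- and
        # 4-device combinations; is_devices_equal(ndev) keeps only combinations that
        # use every visible device.
        if ndev == num_visible_devices:
            picked.append((dp, cp, tp))
    return picked


print(select_configs(8))  # [(1, 4, 2), (2, 2, 2), (2, 4, 1)] for the assumed size tuples
print(select_configs(4))  # these shapes now run only when CUDA_VISIBLE_DEVICES limits JAX to 4 GPUs

The apparent intent is to avoid re-running small mesh shapes on larger nodes while still letting them be exercised explicitly by restricting the visible devices, as the added comment describes.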