diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5a824e8c6..b89c5b7f3 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -305,8 +305,6 @@ def impl_test_context_parallel_attn( window_size=None, ): if qkv_layout.is_thd(): - if is_hip_extension() and cp_strategy == CPStrategy.RING: - pytest.skip("THD + ring on Rocm doesn't support context parallelism.") if cp_strategy == CPStrategy.ALL_GATHER: pytest.skip("THD doesn't support all gather context parallelism.") if not load_balanced and cp_strategy == CPStrategy.RING: