Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,9 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")

config = model_configs[model]

Expand Down Expand Up @@ -761,6 +764,9 @@ def test_gpt_full_activation_recompute(
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")

config = model_configs[model]
torch.compiler.reset() # avoid cache size limit overflow
Expand Down Expand Up @@ -909,6 +915,10 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")

config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
Expand Down Expand Up @@ -2410,6 +2420,9 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")

config = model_configs[model]

Expand Down Expand Up @@ -2561,10 +2574,21 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
max_seqlen_kv=config.seq_len,
)

torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)
if IS_HIP_EXTENSION:
# On some GPUs CK fused attention with THD can produce larger error
tols = dtype_tols(dtype)
tols["atol"] = 1e-3
torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
**tols,
)
else:

torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)


@pytest.mark.parametrize("dtype", param_types)
Expand Down
Loading