Skip to content

Commit 6968a68

Browse files
henrylhtsangfacebook-github-bot
authored andcommitted
Small changes to improve blackwell_fmha_test.py (#4896)
Summary: X-link: facebookresearch/FBGEMM#1922 Pull Request resolved: #4896 Add some BE features: * fix seed * increase backward test to 200 * decrease backward test verbosity * improve error message when assertion fails Reviewed By: q10 Differential Revision: D81992869 fbshipit-source-id: b3e2f2e818f18c2d350cae7ab048c06dccf33b64
1 parent de1da63 commit 6968a68

File tree

1 file changed

+37
-13
lines changed

1 file changed

+37
-13
lines changed

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
from .test_utils import attention_ref, generate_qkv, generate_random_padding_mask
2222

2323
common_settings = {
24-
"verbosity": Verbosity.verbose,
25-
"max_examples": 20,
24+
"verbosity": Verbosity.normal,
25+
"max_examples": 200,
2626
"deadline": None,
2727
"suppress_health_check": [HealthCheck.filter_too_much, HealthCheck.data_too_large],
2828
}
2929

3030
DEBUG = False
31+
SEED = 2
3132

3233
compute_capability = (0, 0)
3334
if torch.cuda.is_available():
@@ -50,21 +51,39 @@ def _allclose(
5051
t_pt: torch.Tensor,
5152
) -> None:
5253
assert t_test.shape == t_ref.shape == t_pt.shape
54+
55+
ratio = 2.0
56+
57+
# Calculate all differences
58+
test_ref_diff = self._abs_max(t_test - t_ref)
59+
test_pt_diff = self._abs_max(t_test - t_pt)
60+
pt_ref_diff = self._abs_max(t_pt - t_ref)
61+
5362
if DEBUG:
5463
# Debug: Print the differences
64+
print(f"DEBUG: Max absolute difference vs ref: {test_ref_diff}")
65+
print(f"DEBUG: Max absolute difference vs pt: {test_pt_diff}")
66+
print(f"DEBUG: Max absolute difference pt vs ref: {pt_ref_diff}")
5567
print(
56-
f"DEBUG: Max absolute difference vs ref: {self._abs_max(t_test - t_ref)}"
57-
)
58-
print(
59-
f"DEBUG: Max absolute difference vs pt: {self._abs_max(t_test - t_pt)}"
60-
)
61-
print(
62-
f"DEBUG: Max absolute difference pt vs ref: {self._abs_max(t_pt - t_ref)}"
63-
)
64-
print(
65-
f"DEBUG: Tolerance check: {self._abs_max(t_test - t_ref)} <= {2 * self._abs_max(t_pt - t_ref) + 1e-5}"
68+
f"DEBUG: Tolerance check: {test_ref_diff} <= {ratio * pt_ref_diff + 1e-5}"
6669
)
67-
assert self._abs_max(t_test - t_ref) <= 2 * self._abs_max(t_pt - t_ref) + 1e-4
70+
71+
# First assertion with gap information
72+
tolerance_threshold = ratio * pt_ref_diff + 1e-4
73+
assert test_ref_diff <= tolerance_threshold, (
74+
f"Tolerance check failed: max_diff={test_ref_diff:.6f} > "
75+
f"threshold={tolerance_threshold:.6f}, gap={test_ref_diff - tolerance_threshold:.6f}"
76+
)
77+
78+
# sanity checks
79+
assert test_ref_diff <= 0.5, (
80+
f"Max difference vs ref too large: {test_ref_diff:.6f} > 0.5, "
81+
f"gap={test_ref_diff - 0.5:.6f}"
82+
)
83+
assert pt_ref_diff <= 0.5, (
84+
f"Max difference pt vs ref too large: {pt_ref_diff:.6f} > 0.5, "
85+
f"gap={pt_ref_diff - 0.5:.6f}"
86+
)
6887

6988
def _generate_qkv(
7089
self,
@@ -121,6 +140,7 @@ def _execute_cutlass_blackwell_attn_dense(
121140
) -> None:
122141
device = torch.accelerator.current_accelerator()
123142
assert device is not None
143+
torch.manual_seed(SEED)
124144
assert seqlen_q <= seqlen_k
125145

126146
# Initialize deterministic variables
@@ -263,6 +283,8 @@ def _execute_cutlass_blackwell_attn_varlen(
263283
device = torch.accelerator.current_accelerator()
264284
assert device is not None
265285

286+
torch.manual_seed(SEED)
287+
266288
# Initialize deterministic variables
267289
out_unpad_d = None
268290
q_ref, k_ref, v_ref = self._generate_qkv(
@@ -501,6 +523,8 @@ def test_jagged_vs_padded_kv(
501523
head_dim = 128
502524
dtype = torch.bfloat16
503525

526+
torch.manual_seed(SEED)
527+
504528
# Create tensors
505529
q_padded = torch.randn(
506530
batch_size,

0 commit comments

Comments
 (0)