@@ -1499,13 +1499,13 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1499
1499
args .append (enable_gqa )
1500
1500
params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
1501
1501
if expected is SDPBackend .FLASH_ATTENTION :
1502
- assert flash_sdp_enabled ()
1503
- assert can_use_flash_attention (params , True )
1502
+ assert flash_sdp_enabled (), "flash_sdp_enabled() is False"
1503
+ assert can_use_flash_attention (params , True ), "can_use_flash_attention(params, True) is False"
1504
1504
elif expected is SDPBackend .EFFICIENT_ATTENTION :
1505
- assert mem_efficient_sdp_enabled ()
1506
- assert can_use_efficient_attention (params , True )
1505
+ assert mem_efficient_sdp_enabled (), "mem_efficient_sdp_enabled() is False"
1506
+ assert can_use_efficient_attention (params , True ), "can_use_efficient_attention(params, True) is False"
1507
1507
elif expected is SDPBackend .MATH :
1508
- assert math_sdp_enabled ()
1508
+ assert math_sdp_enabled (), "math_sdp_enabled() is False"
1509
1509
else :
1510
1510
raise NotImplementedError
1511
1511
return original_fn (query , k_and_v , mask , return_scores )
0 commit comments