Skip to content

Commit e691ace

Browse files
committed
fixup! Disable TMA by default (#607)
1 parent 2eade97 commit e691ace

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fla/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def map_triton_backend_to_torch_device() -> str:
399399
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
400400
is_gather_supported = hasattr(triton.language, 'gather')
401401
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) \
402-
and os.environ.get('FLA_USE_TMA', '0') != '1' and \
402+
and os.environ.get('FLA_USE_TMA', '0') == '1' and \
403403
(hasattr(triton.language, '_experimental_make_tensor_descriptor') or hasattr(triton.language, 'make_tensor_descriptor'))
404404

405405
if is_nvidia and not is_tf32_supported:

0 commit comments

Comments
 (0)