We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2eade97 commit e691aceCopy full SHA for e691ace
fla/utils.py
@@ -399,7 +399,7 @@ def map_triton_backend_to_torch_device() -> str:
399
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
400
is_gather_supported = hasattr(triton.language, 'gather')
401
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) \
402
- and os.environ.get('FLA_USE_TMA', '0') != '1' and \
+ and os.environ.get('FLA_USE_TMA', '0') == '1' and \
403
(hasattr(triton.language, '_experimental_make_tensor_descriptor') or hasattr(triton.language, 'make_tensor_descriptor'))
404
405
if is_nvidia and not is_tf32_supported:
0 commit comments