Commit bbc0df1

njriasan authored and pytorchmergebot committed
[Inductor][Triton] Support TMA before strict 3.4 cutoff (pytorch#159777)
Summary: Inductor's Triton 3.4 release is the most commonly used variant of Triton, but someone working with an alternative version of Triton may not match it. This moves the check from requiring Triton 3.4 to accepting any variant that supports the TMA APIs.

Test Plan: Relying on CI. Should be an NFC.

Rollback Plan:

Reviewed By: davidberard98

Differential Revision: D79378792

Pull Request resolved: pytorch#159777
Approved by: https://github.com/davidberard98

1 parent 33ec6e3 commit bbc0df1
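
The summary describes replacing a version gate with a capability check. A minimal sketch of that pattern follows, assuming the symbol to probe is `triton.language.make_tensor_descriptor`; the real `has_triton_stable_tma_api` in `torch.utils._triton` may check more than this (for example, the device). The helper names `module_exports` and `has_tma_api_sketch` are hypothetical and exist only for illustration.

```python
import importlib


def module_exports(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`.

    Hypothetical helper for illustration; not part of PyTorch or Triton.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Package is missing or failed to import: treat the feature as absent.
        return False
    return hasattr(module, attr)


def has_tma_api_sketch() -> bool:
    # Assumption: the tensor-descriptor entry point is the symbol to probe.
    # The answer depends on which Triton build, if any, is installed.
    return module_exports("triton.language", "make_tensor_descriptor")


# The same probe works for any module, so it can be tried without Triton.
print(module_exports("math", "isqrt"))
print(module_exports("no_such_module_xyz", "anything"))
```

Probing for the symbol means a fork or nightly build with an unusual version number still qualifies as long as it ships the API, which is the situation the summary is concerned with.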

1 file changed, +2 -4 lines changed

torch/_inductor/codegen/triton.py

Lines changed: 2 additions & 4 deletions
@@ -26,7 +26,7 @@
 from torch._prims_common import is_integer_dtype
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton_package, has_triton_stable_tma_api
 
 from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
 from ...utils._sympy.value_ranges import ValueRanges
@@ -1692,14 +1692,12 @@ def __post_init__(self):
     def can_use_tma(
         self,
     ) -> bool:
-        import triton
-
         if not (
             V.graph.get_current_device_or_throw().type == "cuda"
             and torch.cuda.get_device_capability()[0] >= 9
             and config.triton.use_tensor_descriptor
             and config.assume_aligned_inputs
-            and triton.__version__ >= "3.4.0"
+            and has_triton_stable_tma_api()
             # For CUDA The base ptr needs to be aligned
         ):
             log.debug(
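
The removed condition `triton.__version__ >= "3.4.0"` compares strings, so the ordering is lexicographic rather than numeric. The sketch below shows how that can go wrong for a hypothetical future version string such as "3.10.0". The `version_tuple` helper is an illustrative assumption, not something used by PyTorch or Triton.

```python
def version_tuple(version: str) -> tuple:
    """Parse the leading numeric components of a version string.

    Hypothetical helper for illustration only. Anything after a "+"
    (a local build tag) and any non-digit suffix on a component is ignored.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


# String comparison goes character by character: "1" sorts before "4".
print("3.10.0" >= "3.4.0")  # False

# Comparing parsed integers gives the numeric ordering.
print(version_tuple("3.10.0") >= version_tuple("3.4.0"))  # True
```

Even a numeric comparison would still reject a build that carries the TMA APIs under a different version number, which is why the commit probes for the API instead.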
