Skip to content

Commit 853e1da

Browse files
peterbell10 authored and FindHao committed
[FRONTEND] Cleanup nv_tma_desc_type (triton-lang#6508)
This is now dead code as `nv_tma_desc` only exists in the backend.
1 parent e5dda8b commit 853e1da

File tree

4 files changed

+8
-27
lines changed

4 files changed

+8
-27
lines changed

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .. import language
1212
from .._C.libtriton import ir
1313
from ..language import constexpr, semantic, str_to_ty, tensor
14-
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
14+
from ..language.core import _unwrap_if_constexpr, base_value, base_type
1515
from ..runtime.jit import get_jit_fn_file_line
1616
# ideally we wouldn't need any runtime component
1717
from ..runtime import JITFunction
@@ -258,10 +258,6 @@ def make_template(ty):
258258
for attr_name, attr_val in attr_specs:
259259
if attr_path in val_paths:
260260
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
261-
for i, path in enumerate(val_paths):
262-
ty = get_iterable_path(self.arg_types, path)
263-
if isinstance(ty, nv_tma_desc_type):
264-
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
265261
# > add IR values to the template
266262
cursor = 0
267263
handles = [fn.args(i) for i in range(fn.get_num_args())]

python/triton/language/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@
9090
permute,
9191
pi32_t,
9292
pointer_type,
93-
nv_tma_desc_type,
9493
program_id,
9594
range,
9695
reduce,
@@ -219,7 +218,6 @@
219218
"philox_impl",
220219
"pi32_t",
221220
"pointer_type",
222-
"nv_tma_desc_type",
223221
"program_id",
224222
"rand",
225223
"rand4x",
@@ -293,9 +291,6 @@ def str_to_ty(name):
293291
block = block_type(dtype, block_shape)
294292
return tensor_descriptor_type(block, shape_type, stride_type)
295293

296-
if name == "nvTmaDesc":
297-
return nv_tma_desc_type()
298-
299294
if name == "constexpr":
300295
return constexpr
301296

python/triton/language/core.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -643,13 +643,6 @@ def mangle(self) -> str:
643643
return f"P{self.element_ty.mangle()}"
644644

645645

646-
class nv_tma_desc_type(pointer_type):
647-
648-
def __init__(self, const=True, address_space=0):
649-
super().__init__(uint8, const=const, address_space=address_space)
650-
self.name = 'nv_tma_desc_type'
651-
652-
653646
class block_type(dtype):
654647

655648
def __init__(self, element_ty: dtype, shape: List):

python/tutorials/06-fused-attention.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@
1515

1616
import pytest
1717
import torch
18-
from triton.tools.tensor_descriptor import TensorDescriptor
1918

2019
import triton
2120
import triton.language as tl
2221

22+
try:
23+
from triton.tools.tensor_descriptor import TensorDescriptor
24+
HAS_TENSOR_DESC = True
25+
except ModuleNotFoundError:
26+
HAS_TENSOR_DESC = False
27+
2328
DEVICE = triton.runtime.driver.active.get_active_torch_device()
2429

2530

@@ -32,15 +37,7 @@ def is_cuda():
3237

3338

3439
def supports_tma():
35-
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
36-
37-
38-
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
39-
40-
if HAS_TMA_DESC:
41-
print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
42-
else:
43-
print("TMA benchmarks will be running without grid constant TMA descriptor.", )
40+
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9
4441

4542

4643
@triton.jit

0 commit comments

Comments (0)