Skip to content

Commit 2b41842

Browse files
authored
[Triton] Fix python type annotations for descriptor functions (NFC) (#5567)
1 parent 4746ca9 commit 2b41842

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

python/triton/language/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ def __str__(self) -> str:
12651265
return f"tensor_descriptor<{self.type}>"
12661266

12671267
@builtin
1268-
def load(self, offsets: List[tensor], _builder=None) -> tensor:
1268+
def load(self, offsets: List[constexpr | tensor], _builder=None) -> tensor:
12691269
"""Load a block from the descriptor starting at the given element offsets.
12701270
12711271
Values outside of the tensor bounds will be filled with zeros.
@@ -1275,7 +1275,7 @@ def load(self, offsets: List[tensor], _builder=None) -> tensor:
12751275
return semantic.descriptor_load(self, offsets, "", "", _builder)
12761276

12771277
@builtin
1278-
def store(self, offsets: List[tensor], value: tensor, _builder=None) -> tensor:
1278+
def store(self, offsets: List[constexpr | tensor], value: tensor, _builder=None) -> tensor:
12791279
"""Store a block from the descriptor starting at the given element offsets.
12801280
12811281
Values outside of the tensor bounds will be ignored.

python/triton/language/semantic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type,
11471147
return tl._experimental_tensor_descriptor_base(handle, block_ty)
11481148

11491149

1150-
def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_policy: str,
1150+
def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
11511151
builder: ir.builder) -> tl.tensor:
11521152
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
11531153
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
@@ -1156,7 +1156,8 @@ def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_poli
11561156
return tl.tensor(x, desc.type)
11571157

11581158

1159-
def descriptor_store(desc: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
1159+
def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
1160+
builder: ir.builder) -> tl.tensor:
11601161
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
11611162
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
11621163
return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)

0 commit comments

Comments
 (0)