Skip to content

Commit e5ea25e

Browse files
Fix IR generation for ttg.async_copy_global_to_local without mask (#7444)
`async_copy_global_to_shared` calls `ir.value()` when no mask is given, but the default constructor for `value` was never exposed in the Python bindings, which led to a crash.
1 parent ade3d49 commit e5ea25e

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

python/src/ir.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ void init_triton_ir(py::module &&m) {
349349
});
350350

351351
py::class_<Value>(m, "value", py::module_local())
352+
.def(py::init<>())
352353
.def("set_attr",
353354
[](Value &self, std::string &name, Attribute &attr) -> void {
354355
if (Operation *definingOp = self.getDefiningOp())

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ def async_copy_kernel(inp, xnumel, XBLOCK: ttgl.constexpr):
10941094
xindex = ttgl.arange(0, XBLOCK, block_layout)
10951095
mask = tl.max_constancy(xindex < xnumel, 2)
10961096

1097-
async_copy.async_copy_global_to_shared(smem, inp + xindex, mask)
1097+
async_copy.async_copy_global_to_shared(smem, inp + xindex)
10981098
async_copy.async_copy_global_to_shared(smem, inp + xindex, mask, cache_modifier=".ca", eviction_policy="evict_last",
10991099
volatile=True)
11001100

@@ -1124,7 +1124,7 @@ def test_async_copy(fresh_knobs):
11241124
%3 = arith.cmpi slt, %1, %2 {tt.constancy = dense<2> : tensor<1xi32>} : tensor<128xi32, #blocked> loc(#loc)
11251125
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
11261126
%5 = tt.addptr %4, %1 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked> loc(#loc)
1127-
%6 = ttg.async_copy_global_to_local %5, %0 mask %3 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
1127+
%6 = ttg.async_copy_global_to_local %5, %0 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
11281128
%7 = ttg.async_copy_global_to_local %5, %0 mask %3 cacheModifier = ca evictionPolicy = evict_last {isVolatile = true} : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
11291129
%8 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
11301130
ttng.async_copy_mbarrier_arrive %8 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)

0 commit comments

Comments
 (0)