Skip to content

Commit fb68276

Browse files
Merge commit '6af491923135061b107375f1716c7224b1807708'
2 parents 4706aaa + 6af4919 commit fb68276

File tree

18 files changed

+199
-85
lines changed

18 files changed

+199
-85
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ scf::ForOp replaceForOpWithNewSignature(
141141
SmallVectorImpl<std::tuple<Value, Value>> &replacements);
142142
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
143143
ValueRange newIterOperands);
144-
Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
145-
ValueRange newIterOperands);
144+
[[nodiscard]] scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop,
145+
ValueRange newIterOperands);
146146

147147
// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not
148148
// updated and needs to be updated separately for the loop to be correct.

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ ttng::TMEMAllocOp hoistTMEMAlloc(TMEMTokenAllocOp alloc, scf::ForOp &forOp) {
337337
// By hoisting the allocation out of the loop, we need to turn the underlying
338338
// memory variable into a loop-carried dependency.
339339
auto tokType = builder.getType<AsyncTokenType>();
340-
Value newTok = addIterArgsToLoop(builder, forOp, newAlloc.getToken()).front();
340+
forOp = addIterArgsToLoop(builder, forOp, newAlloc.getToken());
341+
Value newTok = forOp.getRegionIterArgs().back();
341342
appendToForOpYield(forOp, joinLastMemoryUses(builder, alloc.getToken()));
342343

343344
if (src != nullptr) {

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ class OptimizeAccumulatorInitPass
249249
}
250250

251251
Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
252-
(void)addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
252+
forOp = addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
253253
loopArgFlagValue =
254254
forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);
255255

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
556556
}
557557

558558
// Patch the loop to add the new loop carried dependencies.
559-
(void)addIterArgsToLoop(builder, forOp, newOperands);
559+
forOp = addIterArgsToLoop(builder, forOp, newOperands);
560560

561561
// Update yield op with temporary yield values
562562
auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -750,7 +750,7 @@ scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) {
750750
newOperands.push_back(zero);
751751
}
752752

753-
(void)addIterArgsToLoop(builder, forOp, newOperands);
753+
forOp = addIterArgsToLoop(builder, forOp, newOperands);
754754

755755
auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
756756
.slice(tmaCounterArgsStartIdx);

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ Value triton::sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
106106
// `in` is live into the loop body. `out` becomes the live-out if the
107107
// loop executes at least once.
108108
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
109-
(void)addIterArgsToLoop(rewriter, forOp, in);
109+
forOp = addIterArgsToLoop(rewriter, forOp, in);
110110
appendToForOpYield(forOp, out);
111111
out = forOp.getResults().back();
112112
continue;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,17 +682,15 @@ scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
682682
return newForOp;
683683
}
684684

685-
Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
686-
ValueRange newIterOperands) {
687-
unsigned curArgIdx = loop.getNumRegionIterArgs();
685+
scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop,
686+
ValueRange newIterOperands) {
688687
scf::ForOp newLoop =
689688
replaceForOpWithNewSignature(rewriter, loop, newIterOperands);
690689
// Save the caller from insertion point invalidation.
691690
if (rewriter.getInsertionPoint() == loop->getIterator())
692691
rewriter.setInsertionPoint(newLoop);
693692
loop.erase();
694-
loop = newLoop;
695-
return loop.getRegionIterArgs().slice(curArgIdx);
693+
return newLoop;
696694
}
697695

698696
scf::WhileOp replaceWhileOpWithNewSignature(

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ addIndexAndPhase(PartitionBuilder &b, scf::ForOp &loop, unsigned numStages,
118118
b.setInsertionPoint(loop);
119119

120120
// Index and phase both start at 0.
121-
unsigned curArgIdx = loop.getNumRegionIterArgs();
122-
auto newArgs = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
121+
loop = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
122+
auto newArgs = loop.getRegionIterArgs().take_back(2);
123123
BlockArgument index = newArgs[0];
124124
BlockArgument phase = newArgs[1];
125125

@@ -488,7 +488,8 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
488488
createTMemAlloc(b, oldAllocOp, /*multiBuffered=*/true, numMmaStages);
489489

490490
// Use placeholder values for the indices in the loop.
491-
auto indexPhase = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
491+
loop = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
492+
auto indexPhase = loop.getRegionIterArgs().take_back(2);
492493
BlockArgument index = indexPhase[0];
493494
BlockArgument phase = indexPhase[1];
494495

python/test/gluon/test_frontend.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,25 +307,25 @@ def anchor(x):
307307
@filecheck_test
308308
@gluon.jit
309309
def test_warp_specialize():
310-
# CHECK-LABEL: tt.func public @test_warp_specialize
310+
# CHECK-LABEL: test_warp_specialize
311311
# CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
312312
# CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
313313
# CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
314314
# CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
315315
# CHECK-NEXT: default {
316-
# CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @"warp_specialize_default{{.*}}"([[A]], [[B]], [[C]])
316+
# CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
317317
# CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
318318
# CHECK-NEXT: }
319319
# CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
320-
# CHECK-NEXT: call @"warp_specialize_worker0{{.*}}"(%arg0, %arg1, %arg2)
320+
# CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
321321
# CHECK-NEXT: warp_return
322322
# CHECK-NEXT: }
323323
# CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
324-
# CHECK-NEXT: call @"warp_specialize_worker1{{.*}}"(%arg0, %arg1, %arg2)
324+
# CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
325325
# CHECK-NEXT: warp_return
326326
# CHECK-NEXT: }
327-
# CHECK-NEXT: call @anchor{{.*}}([[OUTS]]#0)
328-
# CHECK-NEXT: call @"anchor{{.*}}"([[OUTS]]#1, [[OUTS]]#2)
327+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
328+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#1, [[OUTS]]#2)
329329
pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
330330
a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
331331
[warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
@@ -541,6 +541,29 @@ def kernel():
541541
assert "order must be a permutation of 0..(rank-1), but was [1]" in str(e.value.__cause__)
542542

543543

544+
@gluon.jit
545+
def tmem_subslice_kernel():
546+
layout: ttgl.constexpr = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
547+
tmem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, [2, 256, 256], layout)
548+
tmem.subslice(0)
549+
550+
551+
def test_tmem_subslice_constexpr():
552+
expecttest.assert_expected_inline(
553+
run_parser(tmem_subslice_kernel).str_nodebug(), """\
554+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
555+
module {
556+
tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
557+
%result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
558+
%c0_i32 = arith.constant 0 : i32
559+
%c0_i32_0 = arith.constant 0 : i32
560+
%0 = ttg.memdesc_subview %result[%c0_i32, %c0_i32_0, %c0_i32_0] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256>
561+
tt.return
562+
}
563+
}
564+
""")
565+
566+
544567
@gluon.jit
545568
def smem_and_layout_user(smem, a: ttgl.constexpr):
546569
pass
@@ -561,10 +584,10 @@ def kernel():
561584
module {
562585
tt.func public @kernel() attributes {noinline = false} {
563586
%0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
564-
tt.call @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
587+
tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
565588
tt.return
566589
}
567-
tt.func private @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
590+
tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
568591
tt.return
569592
}
570593
}

python/test/unit/language/test_frontend.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_assign_attribute():
4242
scalar = 11
4343
pair = Pair(tl.arange(0, 4), scalar)
4444
# CHECK: %c42_i32 = arith.constant 42 : i32
45-
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
45+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], %c42_i32)
4646
pair.second = 42
4747
anchor(pair)
4848

@@ -58,7 +58,7 @@ def test_augassign_attribute():
5858
# CHECK: %c42_i32 = arith.constant 42 : i32
5959
# CHECK: [[VALUE:%.*]] = arith.addi %c11_i32, %c42_i32
6060
pair.second += 42
61-
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], [[VALUE]])
61+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], [[VALUE]])
6262
anchor(pair)
6363

6464

@@ -69,12 +69,12 @@ def test_jit_method():
6969
# CHECK: %c11_i32 = arith.constant 11 : i32
7070
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
7171
scalar = 11
72-
# CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
72+
# CHECK: [[V:%.*]]:2 = tt.call @{{.*}}unpack{{.*}}([[RANGE]], %c11_i32)
7373
pair = Pair(tl.arange(0, 4), scalar)
7474
a, b = pair.unpack()
75-
# CHECK: call @anchor{{.*}}([[V]]#0)
75+
# CHECK: call @{{.*}}anchor{{.*}}([[V]]#0)
7676
anchor(a)
77-
# CHECK: call @anchor{{.*}}([[V]]#1)
77+
# CHECK: call @{{.*}}anchor{{.*}}([[V]]#1)
7878
anchor(b)
7979

8080

@@ -95,10 +95,10 @@ def test_aggregate_initializers():
9595
# CHECK-LABEL: test_aggregate_initializers
9696
value = TypeWithBuiltinInitializer()
9797
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
98-
# CHECK: call @"anchor{{.*}}"([[RANGE]])
98+
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
9999
anchor(value)
100100
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
101-
# CHECK: call @"anchor{{.*}}"([[RANGE]])
101+
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
102102
value.modify(tl.arange(4, 8))
103103
anchor(value)
104104

@@ -118,11 +118,11 @@ def list_of_functions_constexpr(arg, fns: tl.constexpr):
118118
@triton.jit
119119
def test_list_of_functions():
120120
# CHECK-LABEL: test_list_of_functions
121-
# CHECK: call @"list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward)"
121+
# CHECK: call @{{.*}}list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward)
122122

123-
# CHECK-LABEL: tt.func private @"list_of_functions_constexpr
124-
# CHECK-NEXT: call @anchor
125-
# CHECK-NEXT: call @forward
123+
# CHECK: tt.func private @{{.*}}list_of_functions_constexpr
124+
# CHECK-NEXT: call @{{.*}}anchor
125+
# CHECK-NEXT: call @{{.*}}forward
126126
list_of_functions_constexpr(tl.arange(0, 4), [anchor, forward])
127127

128128

@@ -138,6 +138,61 @@ def test_call_in_loop():
138138
# CHECK-LABEL: test_call_in_loop
139139
acc = 0
140140
# CHECK: scf.for
141-
# CHECK: call @accumulate
141+
# CHECK: call @{{.*}}accumulate
142142
for i in range(10):
143143
acc = accumulate(acc, i)
144+
145+
146+
@tl.core._aggregate
147+
class FunctionParent:
148+
149+
@triton.jit
150+
def function_with_name():
151+
pass
152+
153+
154+
@triton.jit
155+
def function_with_name():
156+
pass
157+
158+
159+
@filecheck_test
160+
@triton.jit
161+
def test_function_name_mangling():
162+
# CHECK-LABEL: test_function_name_mangling
163+
# CHECK: call @test_frontend.function_with_name
164+
# CHECK: call @test_frontend.FunctionParent.function_with_name
165+
function_with_name()
166+
FunctionParent.function_with_name()
167+
168+
169+
@tl.core._aggregate
170+
class AggregateWithConstexpr:
171+
a: tl.tensor
172+
b: tl.constexpr
173+
174+
def __init__(self, a, b):
175+
self.a = a
176+
self.b = b
177+
178+
@staticmethod
179+
def create(a):
180+
return AggregateWithConstexpr(a, tl.constexpr(42))
181+
182+
183+
@triton.jit
184+
def add_rhs_constexpr(agg):
185+
_ = agg.a + agg.b
186+
187+
188+
@filecheck_test
189+
@triton.jit
190+
def test_aggregate_with_constexpr():
191+
# CHECK-LABEL: test_aggregate_with_constexpr
192+
# CHECK: tt.call @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
193+
agg = AggregateWithConstexpr.create(tl.arange(0, 4))
194+
add_rhs_constexpr(agg)
195+
196+
# CHECK: tt.func private @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
197+
# CHECK: %cst = arith.constant dense<42> : tensor<4xi32>
198+
# CHECK: arith.addi %arg0, %cst : tensor<4xi32>
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import triton
2+
3+
4+
@triton.jit
5+
def function_with_name():
6+
pass

0 commit comments

Comments
 (0)