
Commit 79ace62

[Pipeliner] Fix epilogue peeling for num_stages=3+ (#4890)
- The epilogue ramp-down index must start at zero or greater, i.e. at max(0, total_iterations - max_stage), so that it stays aligned with the prologue ramp-up stages.
- If total_iterations < max_stage, the trailing stages are masked off.
- This commit mirrors upstream llvm/llvm-project#112418 and adds a functional test for correctness with num_stages=1,2,3,4.
1 parent 53c2965 commit 79ace62
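
The gist of the fix, as a minimal Python sketch (hypothetical helper, not part of the commit; the actual change builds MLIR arith ops, see the PipelineExpander.cpp diff below):

```python
# Sketch of the epilogue peeling math after this fix. The real code emits
# arith ops; this just models the integer logic.
def epilogue_stage_indices(total_iterations: int, max_stage: int):
    # Start at max(0, total_iterations - max_stage) so the ramp-down lines up
    # with the prologue ramp-up even for short loops.
    iter_i = max(0, total_iterations - max_stage)
    stages = []
    for i in range(1, max_stage + 1):
        # Stage i is masked off when the loop has fewer than i iterations.
        enabled = total_iterations >= i
        stages.append((i, iter_i, enabled))
        iter_i += 1
    return stages

# max_stage=3 (num_stages=4) with only 2 iterations: the last stage is masked
# instead of indexing a negative iteration.
print(epilogue_stage_indices(2, 3))
# [(1, 0, True), (2, 1, True), (3, 2, False)]
```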

File tree

3 files changed: +130 -104 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 35 additions & 35 deletions
```diff
@@ -285,19 +285,6 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
   Location loc = forOp.getLoc();
   SmallVector<Value> predicates(maxStage);
   for (int64_t i = 0; i < maxStage; i++) {
-    if (dynamicLoop) {
-      Type t = ub.getType();
-      // pred = ub > lb + (i * step)
-      Value iv = rewriter.create<arith::AddIOp>(
-          loc, lb,
-          rewriter.create<arith::MulIOp>(
-              loc, step,
-              rewriter.create<arith::ConstantOp>(
-                  loc, rewriter.getIntegerAttr(t, i))));
-      predicates[i] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, iv, ub);
-    }
-
     // special handling for induction variable as the increment is implicit.
     // iv = lb + i * step
     Type t = lb.getType();
@@ -308,6 +295,13 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
         rewriter.create<arith::ConstantOp>(loc,
                                            rewriter.getIntegerAttr(t, i))));
     setValueMapping(forOp.getInductionVar(), iv, i);
+
+    if (dynamicLoop) {
+      // pred = ub > lb + (i * step)
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, iv, ub);
+    }
+
     for (Operation *op : opOrder) {
       if (stages[op] > i)
         continue;
@@ -655,50 +649,56 @@ LogicalResult
 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                     llvm::SmallVector<Value> &returnValues) {
   Location loc = forOp.getLoc();
+  Type t = lb.getType();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
 
-  // range_diff = ub - lb
-  // total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
-  Type t = lb.getType();
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value one =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
-  Value minusOne =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  auto createConst = [&](int v) {
+    return rewriter.create<arith::ConstantOp>(loc,
+                                              rewriter.getIntegerAttr(t, v));
+  };
+
+  // total_iterations = cdiv(range_diff, step);
+  // - range_diff = ub - lb
+  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+  Value zero = createConst(0);
+  Value one = createConst(1);
   Value stepLessZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, step, zero);
   Value stepDecr =
-      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
 
   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
   Value rangeDecr =
       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
+  // If total_iters < max_stage, start the epilogue at zero to match the
+  // ramp-up in the prologue.
+  // start_iter = max(0, total_iters - max_stage)
+  Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
+                                               createConst(maxStage));
+  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
   // Capture predicates for dynamic loops.
   SmallVector<Value> predicates(maxStage + 1);
 
-  for (int64_t i = 0; i < maxStage; i++) {
-    // iterI = total_iters - 1 - i
-    // May go negative...
-    Value minusI =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
-    Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
-        minusI);
+  for (int64_t i = 1; i <= maxStage; i++) {
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
 
-    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+    setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+    // increment to next iterI
+    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
 
     if (dynamicLoop) {
-      // pred = iterI >= 0
-      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, zero);
+      // Disable stages when `i` is greater than total_iters.
+      // pred = total_iters >= i
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
     }
   }
```
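
A small Python model (an illustrative sketch, not part of the commit) of the signed ceiling division emitted above; `arith.divsi` truncates toward zero, which Python's `int()` mirrors:

```python
# total_iterations = (ub - lb + step + (step < 0 ? 1 : -1)) / step,
# i.e. a signed ceiling division that also handles negative steps.
def total_iterations(lb: int, ub: int, step: int) -> int:
    step_decr = 1 if step < 0 else -1
    range_decr = (ub - lb) + step + step_decr
    return int(range_decr / step)  # truncate toward zero, like arith.divsi

assert total_iterations(0, 10, 3) == 4   # iterations 0, 3, 6, 9
assert total_iterations(0, 10, 5) == 2   # iterations 0, 5
assert total_iterations(10, 0, -3) == 4  # iterations 10, 7, 4, 1
```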

python/test/unit/language/test_pipeliner.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -180,3 +180,32 @@ def test_pipeline_vecadd(device):
     assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
     # 3. check alloc
     assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
+
+
+@pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3])
+@pytest.mark.parametrize("NUM_STAGES", [1, 2, 3, 4, 5])
+def test_pipeline_epilogue(ROW_COUNT, NUM_STAGES, device):
+
+    @triton.jit
+    def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+                  NUM_STAGES: tl.constexpr):
+        row_step = tl.num_programs(0)
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        mask = col_offsets < n_cols
+        for row_idx in tl.range(0, n_rows, row_step, num_stages=NUM_STAGES):
+            row_start_ptr = input_ptr + row_idx * input_row_stride
+            input_ptrs = row_start_ptr + col_offsets
+            val = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+            val += 1.0
+            output_row_start_ptr = output_ptr + row_idx * output_row_stride
+            output_ptrs = output_row_start_ptr + col_offsets
+            tl.store(output_ptrs, val, mask=mask)
+
+    width = ROW_COUNT
+    depth = 78
+    x = torch.zeros(width, depth, device=device)
+    y0 = torch.rand_like(x)
+    n_rows, n_cols = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    kernel_up[(1, )](y0, x, x.stride(0), y0.stride(0), n_rows, n_cols, BLOCK_SIZE, NUM_STAGES)
+    assert (y0 == torch.ones_like(x)).all()
```
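
With a one-program grid, the loop runs ROW_COUNT iterations, so the grid above deliberately includes cases where the loop is shorter than the pipeline depth. Assuming the pipeliner's max_stage = NUM_STAGES - 1 convention, a quick sketch enumerates which configurations exercise the new masked-epilogue path:

```python
# Configurations where total_iterations < max_stage, i.e. where the trailing
# epilogue stages must be masked (max_stage = NUM_STAGES - 1 is assumed).
for row_count in [0, 1, 2, 3]:
    for num_stages in [1, 2, 3, 4, 5]:
        total_iterations = row_count  # one row per loop iteration
        max_stage = num_stages - 1
        if total_iterations < max_stage:
            print(f"masked epilogue: ROW_COUNT={row_count}, NUM_STAGES={num_stages}")
```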

test/TritonGPU/loop-pipeline.mlir

Lines changed: 66 additions & 69 deletions
```diff
@@ -84,19 +84,18 @@
 // AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]]
 // AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]]
 // AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]]
-// AMD: %[[DIVUI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]]
-// AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]]
-// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]]
-// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4
-// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[FOR]]#5
-// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}}
-// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_28]]
-// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %[[FOR]]#2
-// AMD: scf.yield %[[DOT_32]]
+// AMD: %[[DIVSI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]]
+// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[DIVSI_26]], %{{.*}}
+// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4
+// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5
+// AMD: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}}
+// AMD: %[[IF_31:.*]] = scf.if %[[CMPI_27]]
+// AMD: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %{{.*}}#2
+// AMD: scf.yield %[[DOT_33]]
 // AMD: } else {
-// AMD: scf.yield %[[FOR]]#2
+// AMD: scf.yield %{{.*}}#2
 // AMD: }
-// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_28]], %[[IF_30]], %[[FOR]]#2
+// AMD: %[[SELECT_32:.*]] = arith.select %[[CMPI_27]], %[[IF_31]], %{{.*}}#2
 // AMD: triton_gpu.local_dealloc %{{.*}}
 // AMD: triton_gpu.local_dealloc %{{.*}}
 
@@ -414,35 +413,33 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_58]]
 // AMD: scf.yield %[[DOT_45]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_56]], %[[MEMDESC_SUBVIEW_57]], %[[MEMDESC_SUBVIEW_58]], %[[LOAD_48]], %[[LOAD_53]]
 // AMD: }
-// AMD: %[[ADDI_26:.*]] = arith.addi %{{.*}}, %{{.*}}-1
-// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[ADDI_26]], %{{.*}}
-// AMD: %[[ADDI_28:.*]] = arith.addi %{{.*}}, %{{.*}}-2
-// AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %[[ADDI_28]], %{{.*}}
-// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %{{.*}}#4
-// AMD: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %{{.*}}#5
-// AMD: %[[IF_32:.*]] = scf.if %[[CMPI_27]]
-// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0
-// AMD: scf.yield %[[DOT_43]]
+// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
+// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
+// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4
+// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5
+// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]]
+// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[LOCAL_LOAD_29]], %{{.*}}#0
+// AMD: scf.yield %[[DOT_41]]
 // AMD: } else {
-// AMD: scf.yield %{{.*}}#0
+// AMD: scf.yield %{{.*}}#0
 // AMD: }
-// AMD: %[[ADDI_33:.*]] = arith.addi %{{.*}}#3, %{{.*}}
-// AMD: %[[CMPI_34:.*]] = arith.cmpi slt, %[[ADDI_33]], %{{.*}}
-// AMD: %[[SELECT_35:.*]] = arith.select %[[CMPI_34]], %[[ADDI_33]], %{{.*}}
-// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
-// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_36]]
-// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
-// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_37]]
-// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[IF_32]], %{{.*}}#0
-// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_36]]
-// AMD: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_37]]
-// AMD: %[[IF_41:.*]] = scf.if %[[CMPI_29]]
-// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]]
-// AMD: scf.yield %[[DOT_43]]
+// AMD: %[[ADDI_31:.*]] = arith.addi %{{.*}}#3, %{{.*}}
+// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}}
+// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}}
+// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_33]], %{{.*}}, %{{.*}}]
+// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_34]]
+// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_33]], %{{.*}}, %{{.*}}]
+// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_35]]
+// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#0
+// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_34]]
+// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_35]]
+// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]]
+// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]]
+// AMD: scf.yield %[[DOT_41]]
 // AMD: } else {
-// AMD: scf.yield %[[SELECT_38]]
+// AMD: scf.yield %[[SELECT_36]]
 // AMD: }
-// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[IF_41]], %[[SELECT_38]]
+// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]]
 // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]]
 // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]]
 
@@ -976,6 +973,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
 // AMD-LABEL: tt.func @indirect_load_shared_layout
+// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc
+// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc
 // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
 // AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]]
 // AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]]
@@ -998,44 +997,42 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]]
 // AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]]
 // AMD: }
-// AMD: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %{{.*}}-1
-// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}}
-// AMD: %[[ADDI_23:.*]] = arith.addi %{{.*}}, %{{.*}}-2
-// AMD: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_23]], %{{.*}}
-// AMD: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %{{.*}}#4
-// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#5
-// AMD: %[[IF_27:.*]] = scf.if %[[CMPI_22]]
-// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0
-// AMD: scf.yield %[[DOT_47]]
+// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
+// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
+// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4
+// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5
+// AMD: %[[IF_25:.*]] = scf.if %[[CMPI_21]]
+// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[LOCAL_LOAD_24]], %{{.*}}#0
+// AMD: scf.yield %[[DOT_45]]
 // AMD: } else {
 // AMD: scf.yield %{{.*}}#0
 // AMD: }
-// AMD: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}#1, %{{.*}}
-// AMD: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_24]]
-// AMD: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]]
-// AMD: %[[EXPAND_DIMS_31:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32}
-// AMD: %[[BROADCAST_32:.*]] = tt.broadcast %[[EXPAND_DIMS_31]]
-// AMD: %[[MULI_33:.*]] = arith.muli %{{.*}}, %[[BROADCAST_32]]
-// AMD: %[[ADDPTR_34:.*]] = tt.addptr %{{.*}}, %[[MULI_33]]
-// AMD: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_24]]
-// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]]
-// AMD: %[[ADDI_37:.*]] = arith.addi %{{.*}}#3, %{{.*}}
-// AMD: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
-// AMD: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
-// AMD: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
-// AMD: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_40]]
-// AMD: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
-// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]]
-// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[IF_27]], %{{.*}}#0
-// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_40]]
-// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_41]]
-// AMD: %[[IF_45:.*]] = scf.if %[[CMPI_24]]
-// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]]
-// AMD: scf.yield %[[DOT_47]]
+// AMD: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}#1, %{{.*}}
+// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_22]]
+// AMD: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]]
+// AMD: %[[EXPAND_DIMS_29:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32}
+// AMD: %[[BROADCAST_30:.*]] = tt.broadcast %[[EXPAND_DIMS_29]]
+// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[BROADCAST_30]]
+// AMD: %[[ADDPTR_32:.*]] = tt.addptr %{{.*}}, %[[MULI_31]]
+// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]]
+// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_33]]
+// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}}
+// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}}
+// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}}
+// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_37]], %{{.*}}, %{{.*}}]
+// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_38]]
+// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_37]], %{{.*}}, %{{.*}}]
+// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_39]]
+// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_25]], %{{.*}}#0
+// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]]
+// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]]
+// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]]
+// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]]
+// AMD: scf.yield %[[DOT_45]]
 // AMD: } else {
-// AMD: scf.yield %[[SELECT_42]]
+// AMD: scf.yield %[[SELECT_40]]
 // AMD: }
-// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[IF_45]], %[[SELECT_42]]
+// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]]
 // AMD: triton_gpu.local_dealloc %{{.*}}
 // AMD: triton_gpu.local_dealloc %{{.*}}
```
