
Commit a35e8b3

Merge commit 'c6ee6266538f252035b8836643b5fa05fa61b707'
2 parents 7f1d43e + c6ee626

9 files changed (+86, −36 lines)

bench/triton_bench/swiglu.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def forward(ctx, a, alpha, precision_config, routing_data):
         n_tokens,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
-        EVEN_N=(N // 2) % 2 == 0,
+        EVEN_N=(N // 2) % BLOCK_N == 0,
         M_BLOCKS=M_BLOCKS,
         N_BLOCKS=N_BLOCKS,
         flexpoint_saturate_inf=flex_ctx.saturate_inf,
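
The fix matters because EVEN_N is meant to guarantee that the half width N // 2 splits into whole BLOCK_N tiles, not merely that it is even. A minimal sketch of the distinction, with hypothetical sizes (the real kernel receives N and BLOCK_N from its caller):

    # Hypothetical sizes; swiglu splits the last dimension in half.
    N, BLOCK_N = 4160, 128
    half_n = N // 2                       # 2080
    old_even_n = half_n % 2 == 0          # True: checks parity only
    new_even_n = half_n % BLOCK_N == 0    # False: 2080 leaves a ragged last tile
    # Under the old flag, a kernel guarded by EVEN_N could skip bounds
    # masking even though the final BLOCK_N tile is only partially valid.
    assert old_even_n and not new_even_n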

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 1 addition & 1 deletion
@@ -1174,7 +1174,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
   // in this synchronization edge.
   decltype(operandDefs) nextOperandDefs;
   for (auto &[defOp, defPartition] : operandDefs) {
-    if (defPartition == partition)
+    if (defPartition == partition && inBody(node.op)->isBeforeInBlock(mmaOp))
       defs.push_back(defOp);
     else
       nextOperandDefs.emplace_back(defOp, defPartition);

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 0 additions & 4 deletions
@@ -38,10 +38,6 @@ static bool isTMACompatibleEncoding(Attribute enc) {
   if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(enc)) {
     return !nvmma.getTransposed();
   }
-  if (auto swizzled = dyn_cast<ttg::SwizzledSharedEncodingAttr>(enc)) {
-    return swizzled.getVec() == 1 && swizzled.getPerPhase() == 1 &&
-           swizzled.getMaxPhase() == 1;
-  }
   return false;
 }

python/test/unit/runtime/test_compilation_listener.py

Lines changed: 14 additions & 14 deletions
@@ -18,13 +18,13 @@ def cumsum_kernel(ptr):


 def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
-    captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], CompileTimes, bool], None] = None
+    captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None

-    def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any], times: CompileTimes,
-                         cache_hit: bool) -> None:
+    def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, Any],
+                         times: CompileTimes, cache_hit: bool) -> None:
         nonlocal captured
         assert captured is None
-        captured = (src, metadata, times, cache_hit)
+        captured = (src, metadata, metadata_group, times, cache_hit)

     fresh_knobs_except_libraries.compilation.listener = compile_listener

@@ -34,17 +34,17 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any],
     assert captured is not None

     # No cache hit at first
-    assert not captured[3]
+    assert not captured[4]

     # Expected metadata
     assert len(captured[1]["hash"]) > 0
     assert isinstance(captured[1]["target"], GPUTarget)

     # It in fact did take some time to do compilation
-    assert captured[2].ir_initialization > 0
-    assert captured[2].total_lowering > 0
-    assert captured[2].store_results > 0
-    assert captured[2].total > 0
+    assert captured[3].ir_initialization > 0
+    assert captured[3].total_lowering > 0
+    assert captured[3].store_results > 0
+    assert captured[3].total > 0

     # Now lets create a new instance of the same kernel to pick up cache_hit=True
     cumsum_kernel.device_caches.clear()
@@ -53,14 +53,14 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any],

     assert captured is not None
     # Cache hit!
-    assert captured[3]
+    assert captured[4]

     # Expected metadata
     assert len(captured[1]["hash"]) > 0
     assert isinstance(captured[1]["target"], GPUTarget)

     # It in fact did take some time to do compilation
-    assert captured[2].ir_initialization > 0
-    assert captured[2].total_lowering == 0
-    assert captured[2].store_results == 0
-    assert captured[2].total > 0
+    assert captured[3].ir_initialization > 0
+    assert captured[3].total_lowering == 0
+    assert captured[3].store_results == 0
+    assert captured[3].total > 0

python/triton/compiler/compiler.py

Lines changed: 3 additions & 1 deletion
@@ -305,6 +305,7 @@ def compile(src, target=None, options=None):
             compilation_listener(
                 src=src,
                 metadata=res.metadata._asdict(),
+                metadata_group=metadata_group,
                 times=timer.end(),
                 cache_hit=True,
             )
@@ -384,7 +385,8 @@ def compile(src, target=None, options=None):

     # notify any listener
     if compilation_listener:
-        compilation_listener(src=src, metadata=metadata, times=timer.end(), cache_hit=False)
+        compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
+                             cache_hit=False)
     # return handle to compiled kernel
     return CompiledKernel(src, metadata_group, hash)

python/triton/knobs.py

Lines changed: 2 additions & 2 deletions
@@ -306,8 +306,8 @@ def total(self) -> int:

 class CompilationListener(Protocol):

-    def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], times: CompileTimes,
-                 cache_hit: bool) -> None:
+    def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, Any],
+                 times: CompileTimes, cache_hit: bool) -> None:
         ...

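For illustration, a minimal listener conforming to the updated protocol might look like the sketch below. The import locations and the triton.knobs.compilation.listener registration path are assumptions inferred from this diff and the test fixture above, and the sketch presumes metadata_group maps artifact filenames to their cached paths; none of that is confirmed here.

    # A sketch only; import paths are assumptions based on where this diff
    # defines the types (ASTSource/IRSource in triton.compiler, CompileTimes
    # in triton.knobs).
    from typing import Any, Union

    import triton
    from triton.compiler import ASTSource, IRSource
    from triton.knobs import CompileTimes


    def log_compile(*, src: Union[ASTSource, IRSource], metadata: dict[str, Any],
                    metadata_group: dict[str, Any], times: CompileTimes,
                    cache_hit: bool) -> None:
        # metadata_group presumably holds the compiled artifact group keyed
        # by filename; metadata["hash"] and times.total are exercised by the
        # test above.
        print(f"{metadata['hash']}: cache_hit={cache_hit}, total={times.total}, "
              f"artifacts={sorted(metadata_group)}")


    # Registration path assumed from the fresh_knobs fixture used in the test.
    triton.knobs.compilation.listener = log_compile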
test/TritonGPU/load-mma-specialization.mlir

Lines changed: 60 additions & 8 deletions
@@ -4,9 +4,11 @@

 #acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#oper_layout_trans = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
 // CHECK-DAG: [[SHARED:#.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared_trans = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#nvmma_smem = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
 #smem = #ttg.shared_memory
 // CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 #acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
@@ -791,8 +793,8 @@ tt.func @matmul_scaled_rhs_scales_tma(
     %k_tiles: i32,
     %off_m: i32,
     %off_n: i32,
-    %a_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>>,
-    %b_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>>,
+    %a_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>>,
+    %b_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>>,
     %b_scale_desc: !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>>
 ) {
   %true = arith.constant true
@@ -814,14 +816,14 @@ tt.func @matmul_scaled_rhs_scales_tma(

     // CHECK: ttng.wait_barrier
     // CHECK-COUNT-3: async_tma_copy_global_to_local {{.*}} {ttg.partition = 2 : i32}
-    %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>> -> tensor<128x64xf8E4M3FN, #oper_layout>
-    %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>> -> tensor<128x64xf8E4M3FN, #oper_layout>
+    %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>> -> tensor<128x64xf8E4M3FN, #oper_layout>
+    %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>> -> tensor<128x64xf8E4M3FN, #oper_layout>
     %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>> -> tensor<128x8xi8, #oper_layout>

-    %a_sh = ttg.local_alloc %a_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>
-    %b_sh_raw = ttg.local_alloc %b_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>
+    %a_sh = ttg.local_alloc %a_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem>
+    %b_sh_raw = ttg.local_alloc %b_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem>
     // CHECK-NEXT: memdesc_trans {{.*}} ttg.partition = 1 : i32
-    %b_sh = ttg.memdesc_trans %b_sh_raw {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>
+    %b_sh = ttg.memdesc_trans %b_sh_raw {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>

     // CHECK-NEXT: wait_barrier {{.*}} {ttg.partition = 1 : i32}

@@ -831,7 +833,7 @@ tt.func @matmul_scaled_rhs_scales_tma(

     // CHECK-NEXT: [[IS_LAST:%.*]] = arith.cmpi eq, %arg6, [[LAST_ITER]]
     // CHECK-NEXT: tc_gen5_mma_scaled {{.*}} {ttg.partition = 1 : i32}
-    %mma_tok = ttng.tc_gen5_mma_scaled %a_sh, %b_sh, %c_tmem[%c_tok], %a_scales_tmem, %b_scales_tmem, %true, %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+    %mma_tok = ttng.tc_gen5_mma_scaled %a_sh, %b_sh, %c_tmem[%c_tok], %a_scales_tmem, %b_scales_tmem, %true, %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>

     %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
     scf.yield %c : tensor<128x128xf32, #acc_layout>
@@ -1125,6 +1127,56 @@ tt.func @specialize_mma_only(%rhs_desc: !tt.tensordesc<tensor<64x128xf16, #share
   tt.return
 }

+// CHECK-LABEL: @load_scale_mma_user
+tt.func @load_scale_mma_user(
+    %lhs: !ttg.memdesc<128x64xf16, #shared, #smem>,
+    %rhs: !ttg.memdesc<64x128xf16, #shared, #smem>,
+    %scales_desc: !tt.tensordesc<tensor<8x128xi8, #shared>>,
+    %b_scales: !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>,
+    %ub: i32
+) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %true = arith.constant true
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
+
+  // CHECK: scf.for
+  %out = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
+    // CHECK: wait_barrier [[EMPTY_BAR:%.*]], %{{.*}}partition = 2
+    // CHECK: barrier_expect [[SCALES_BAR:%.*]], 1024 {{.*}}partition = 2
+    // CHECK: async_tma_copy_global_to_local {{.*}}partition = 2
+    %scales_result = tt.descriptor_load %scales_desc[%i, %i] : !tt.tensordesc<tensor<8x128xi8, #shared>> -> tensor<8x128xi8, #oper_layout>
+    %scales_shared = ttg.local_alloc %scales_result : (tensor<8x128xi8, #oper_layout>) -> !ttg.memdesc<8x128xi8, #shared, #smem>
+    // CHECK: wait_barrier [[SCALES_BAR]]{{.*}}partition = 0
+    // CHECK-NEXT: [[SCALES_REG:%.*]] = ttg.local_load {{.*}}partition = 0
+    // CHECK-NEXT: arrive_barrier [[EMPTY_BAR]]{{.*}}partition = 0
+    %scales_reg = ttg.local_load %scales_shared : !ttg.memdesc<8x128xi8, #shared, #smem> -> tensor<8x128xi8, #oper_layout>
+    // CHECK-NEXT: [[SCALES_TRANS:%.*]] = tt.trans [[SCALES_REG]] {{.*}}partition = 0
+    %scales_T = tt.trans %scales_reg {order = array<i32: 1, 0>} : tensor<8x128xi8, #oper_layout> -> tensor<128x8xi8, #oper_layout_trans>
+    // CHECK-NEXT: wait_barrier [[SCALES_TMEM_BAR:%.*]], %arg{{[0-9]+}} {{.*}}partition = 0
+    // CHECK-NEXT: tmem_store [[SCALES_TRANS]], [[SCALES_TMEM:%.*]], %true {{.*}}partition = 0
+    %scales_tmem = ttng.tmem_alloc %scales_T : (tensor<128x8xi8, #oper_layout_trans>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+    // CHECK-NEXT: arrive_barrier [[SCALES_READY_BAR:%.*]], 1 {{.*}}partition = 0
+
+    // CHECK: wait_barrier [[SCALES_READY_BAR]]{{.*}}partition = 1
+    %acc_tmem, %acc_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    // CHECK-NEXT: tc_gen5_mma_scaled {{.*}} [[SCALES_TMEM]]{{.*}} [[USER_BAR:%.*]][%true], [[SCALES_TMEM_BAR]][%true] {{.*}}partition = 1
+    %mma_tok = ttng.tc_gen5_mma_scaled %lhs, %rhs, %acc_tmem[%acc_tok], %scales_tmem, %b_scales, %true, %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+
+    // CHECK: wait_barrier [[USER_BAR]]{{.*}}partition = 0
+    // CHECK-NEXT: tmem_load
+    %c, %load_tok = ttng.tmem_load %acc_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
+    // CHECK: arrive_barrier [[USER_DONE:%.*]], 1 {{.*}}partition = 0
+    // CHECK: wait_barrier [[USER_DONE]]{{.*}}partition = 1
+
+    "user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
+
+    scf.yield %c : tensor<128x128xf32, #acc_layout>
+  } {tt.warp_specialize, tt.num_stages = 3 : i32}
+  "use"(%out) : (tensor<128x128xf32, #acc_layout>) -> ()
+  tt.return
+}
+
 }

 // -----

test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir

Lines changed: 3 additions & 3 deletions
@@ -49,11 +49,11 @@ tt.func public @tma_scatter(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %ar

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
 // CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-// CHECK-DAG: #[[SWIZZLE_3D:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0]}>
+// CHECK-DAG: #[[SWIZZLE_MMA:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
 // CHECK-DAG: #[[SWIZZLE_2D:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 tt.func public @tma_scatter(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) {
-  // CHECK: tt.make_tensor_descriptor {{.*}} : <f32>, <tensor<1x256x32xf32, #[[SWIZZLE_3D]]>>
-  // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<tensor<1x256x32xf32, #[[SWIZZLE_3D]]>> -> tensor<256x32xf32, #[[BLOCKED]]>
+  // CHECK: tt.make_tensor_descriptor {{.*}} : <f32>, <tensor<1x256x32xf32, #[[SWIZZLE_MMA]]>>
+  // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<tensor<1x256x32xf32, #[[SWIZZLE_MMA]]>> -> tensor<256x32xf32, #[[BLOCKED]]>
   // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<256x32xf32, #[[BLOCKED]]>) -> !ttg.memdesc<256x32xf32, #[[SWIZZLE_2D]], #smem>
   %c1_i32 = arith.constant 1 : i32
   %c1_i64 = arith.constant 1 : i64

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -436,7 +436,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
   for (auto [ws, stateMap] : llvm::zip(wsOps, warpToState)) {
     Block *before = ws->getBlock();
     Block *after = b.splitBlock(before, ws->getIterator());
-    b.setInsertionPointToEnd(before);
+    TritonLLVMIRRewriter b(ws.getLoc(), OpBuilder::atBlockEnd(before));
     Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
     for (auto [i, state] : llvm::enumerate(stateMap)) {
       Value stateVal = b.i8_val(state);
@@ -469,7 +469,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
     b.create<LLVM::BrOp>(&ws.getDefaultRegion().front());

     ws.getDefaultRegion().walk([&, ws = ws](WarpYieldOp op) mutable {
-      b.setInsertionPoint(op);
+      TritonLLVMIRRewriter b(op.getLoc(), op);
       createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
                     /*aligned=*/false);
       if (auto actRegs = ws.getActualRegisters())
