Commit 8e98739

Merge commit 'd13a0404d4248117ee8ac2129235e853c74f8c53'
2 parents: 023102d + d13a040

18 files changed: +412 -65 lines

lib/Conversion/TritonToTritonGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_triton_library(TritonToTritonGPU
   MLIRPass
   MLIRTransforms
   TritonIR
+  ProtonIR
   TritonGPUIR
   TritonGPUTransforms
 )

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 14 additions & 1 deletion
@@ -16,6 +16,8 @@
 #include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+
 namespace {
 
 using namespace mlir;
@@ -555,7 +557,17 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
       TritonFuncOpPattern>(typeConverter, context);
 }
-
+// Proton patterns
+// NOTE: Because Proton's inputs are scalars, not tensors, this conversion
+// isn't strictly necessary. However, one could envision passing tensors in
+// for Triton-object-specific tracing operations, in which case we would need
+// to fill in the OpConversionPattern.
+void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
+                            RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
+                                                           context);
+}
 //
 // SCF patterns
 //
@@ -770,6 +782,7 @@ class ConvertTritonToTritonGPU
     populateArithPatternsAndLegality(typeConverter, patterns, target);
     populateMathPatternsAndLegality(typeConverter, patterns, target);
     populateTritonPatterns(typeConverter, patterns, numCTAs);
+    populateProtonPatterns(typeConverter, patterns);
     // TODO: can we use
     // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
     populateSCFPatterns(typeConverter, patterns);

lib/Tools/LinearLayout.cpp

Lines changed: 0 additions & 5 deletions
@@ -117,11 +117,6 @@ std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
 // outDim as columns. In other words, finds the number of linearly-independent
 // bases for this output dimension.
 int getMatrixRank(std::unique_ptr<uint64_t[]> m, int numRows, int numCols) {
-  // f2reduce underflows if the number of cols is 0, return the rank early in
-  // this case.
-  if (numCols == 0) {
-    return 0;
-  }
   // stride is specified in number of 64-bit words per row, and we pack our
   // matrix so that there's only one uint64_t per row.
   assert(numCols <= 64);
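
For intuition, the quantity getMatrixRank computes is the rank of a bit-matrix over GF(2), with each row packed into a single 64-bit word as the surrounding comments describe. Below is a standalone Python sketch of that computation, for illustration only (it is not f2reduce's algorithm); note that zero columns simply yields all-zero rows and rank 0 here, which is the corner case the removed guard was working around.

def gf2_rank(rows: list[int]) -> int:
    """Rank over GF(2) of a matrix whose rows are bit-packed into ints,
    one bit per column (at most 64 columns, matching the C++ packing)."""
    pivots: list[int] = []
    for row in rows:
        for p in pivots:
            # Taking the smaller of row and row ^ p xors p away exactly
            # when p's leading bit is still set in row.
            row = min(row, row ^ p)
        if row:
            pivots.append(row)  # row is independent of all pivots so far
    return len(pivots)

# 0b011 ^ 0b010 == 0b001, so the third row adds nothing: rank 2.
assert gf2_rank([0b011, 0b010, 0b001]) == 2
assert gf2_rank([0b001, 0b010, 0b100]) == 3  # identity rows: full rank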

python/src/ir.cc

Lines changed: 9 additions & 1 deletion
@@ -31,6 +31,8 @@
 #include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/Support/SourceMgr.h"
 
+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+
 namespace {
 
 namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
   registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
                   math::MathDialect, arith::ArithDialect, scf::SCFDialect,
                   ::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
-                  LLVM::LLVMDialect, mlir::ub::UBDialect>();
+                  ::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
+                  mlir::ub::UBDialect>();
   mlir::LLVM::registerInlinerInterface(registry);
   registerBuiltinDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);
@@ -1654,6 +1657,11 @@ void init_triton_ir(py::module &&m) {
               std::vector<int32_t> &tensorShape) -> Value {
             return self.create<MakeTensorDescOp>(base, shape, strides,
                                                  tensorShape);
+          })
+      // Proton Ops
+      .def("create_proton_record",
+           [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
+             self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
           });
 
   py::class_<PassManager>(m, "pass_manager", py::module_local())
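
The new binding makes proton.record reachable from Python wherever a TritonOpBuilder is in scope. A minimal usage sketch under stated assumptions: the helper name emit_timed_region is hypothetical, the (is_start, region_id) argument order mirrors the lambda above, and the builder must already have an insertion point inside a function body, as Triton's code generator arranges during kernel compilation.

from triton._C.libtriton import ir  # assumes a Triton build that registers the Proton dialect

def emit_timed_region(builder: "ir.builder", region_id: int) -> None:
    """Bracket subsequently emitted ops with proton.record start/end markers."""
    builder.create_proton_record(True, region_id)   # isStart = True: open the region
    # ... caller emits the ops whose execution should be attributed here ...
    builder.create_proton_record(False, region_id)  # isStart = False: close the region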

python/test/unit/language/test_core.py

Lines changed: 5 additions & 1 deletion
@@ -61,6 +61,10 @@ def xpu_has_fp64():
     return target.arch['has_fp64']
 
 
+# No need to emulate NumPy 2.0 if the user has NumPy 2.0
+if np.__version__[0] != "1":
+    promotion_numpy_2_0 = contextlib.nullcontext
+
 # TODO: enable multiple cta cluster testing.
 # num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1]
 num_ctas_list = [1]
@@ -1698,7 +1702,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
 
 
 @pytest.mark.interpreter
-@pytest.mark.skipif((is_cuda() and torch.cuda.get_device_capability()[0] < 9) or is_hip(),
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
                     reason="Requires compute capability >= 9 for NV")
 def test_load_scope_sem_coop_grid_cta_not_one(device):
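
The rebinding above works because contextlib.nullcontext has the same construct-and-enter interface as a real context manager while doing nothing. A self-contained sketch of the pattern, where _emulate_numpy2_promotion is a hypothetical stand-in for the real NumPy 1.x emulation helper defined earlier in test_core.py:

import contextlib

import numpy as np

@contextlib.contextmanager
def _emulate_numpy2_promotion():
    # Hypothetical stand-in: the real helper adjusts type-promotion
    # expectations so NumPy 1.x behaves like NumPy 2.0 for the tests.
    yield

promotion_numpy_2_0 = (_emulate_numpy2_promotion
                       if np.__version__.startswith("1.") else
                       contextlib.nullcontext)

with promotion_numpy_2_0():
    pass  # promotion-sensitive assertions run unchanged under either version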

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 110 additions & 11 deletions
@@ -86,43 +86,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: tt.load
 // CHECK: %[[SLICEA0:.+]] = ttg.local_load
 // CHECK: %[[SLICEB0:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: tt.load
 // CHECK: %[[SLICEA1:.+]] = ttg.local_load
 // CHECK: %[[SLICEB1:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: %[[SLICEA2:.+]] = ttg.local_load
 // CHECK: %[[SLICEB2:.+]] = ttg.local_load
 // CHECK: %[[SLICEA3:.+]] = ttg.local_load
 // CHECK: %[[SLICEB3:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: ttg.local_store
 // CHECK: ttg.local_store
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: scf.yield
 // CHECK: amdgpu.cond_barrier %[[WARPLOW]]
@@ -169,9 +169,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
     %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
     %29 = tt.load %28 : tensor<64x256x!tt.ptr<f16>, #blocked>
-    %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-    %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-    %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma>
+    %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
     %33 = arith.addi %arg14, %c1_i32 : i32
     %34 = arith.cmpi slt, %33, %c1_i32 : i32
     %35 = arith.select %34, %33, %c0_i32 : i32
@@ -189,6 +189,105 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+// CHECK: gpu.barrier
+// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
+// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
+// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
+// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
+// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
+// CHECK: scf.for
+
+// CHECK: %[[SLICEA0:.+]] = ttg.local_load
+// CHECK: %[[SLICEB0:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: %[[SLICEA1:.+]] = ttg.local_load
+// CHECK: %[[SLICEB1:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.s.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.local_store
+// CHECK: ttg.local_store
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: scf.yield
+// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_medium(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+      %33 = arith.addi %arg14, %c1_i32 : i32
+      %34 = arith.cmpi slt, %33, %c1_i32 : i32
+      %35 = arith.select %34, %33, %c0_i32 : i32
+      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 // CHECK-LABEL: pingpong_reject
 // CHECK-COUNT-2: local_load
 // CHECK-NOT: local_load

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM
   LINK_LIBS PUBLIC
     TritonGPUToLLVM
     TritonAMDGPUIR
+    TritonProtonToLLVM
 )

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 6 additions & 0 deletions
@@ -24,6 +24,8 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
+#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
+
 namespace mlir::triton {
 #define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM
 #include "TritonAMDGPUToLLVM/Passes.h.inc"
@@ -228,6 +230,10 @@ struct ConvertTritonAMDGPUToLLVM
                                                patterns);
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, commonBenefit);
+
+    mlir::triton::proton::populateRecordOpToLLVMPattern(
+        typeConverter, patterns, targetInfo, commonBenefit);
+
     mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
