
Commit 0c3aa91

Merge commit 'c2c193a9059707303db4650ad4b8aee03608e921'
2 parents: 262102c + c2c193a

11 files changed: +139 -863 lines changed

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 24 additions & 0 deletions
@@ -187,6 +187,29 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
   }
 };
 
+// When reducing a 1D tensor the order of elements of the tensor doesn't matter.
+// Therefore we can relax the reshape to allow it to re-order elements.
+class CombineReshapeReducePatterns : public mlir::OpRewritePattern<ReshapeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::ReshapeOp reshapeOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (reshapeOp.getAllowReorder())
+      return failure();
+    if (reshapeOp.getType().getRank() != 1)
+      return failure();
+    for (Operation *user : reshapeOp->getUsers()) {
+      if (!isa<triton::ReduceOp, triton::HistogramOp>(user))
+        return failure();
+    }
+    rewriter.modifyOpInPlace(reshapeOp,
+                             [&]() { reshapeOp.setAllowReorder(true); });
+    return success();
+  }
+};
+
 class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
 public:
   void runOnOperation() override {
@@ -203,6 +226,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
     patterns.add<CombineSelectMaskedLoadPattern>(context);
     patterns.add<CombineAddPtrPattern>(context);
     patterns.add<CombineBroadcastMulReducePattern>(context);
+    patterns.add<CombineReshapeReducePatterns>(context);
 
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
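
At the kernel level, the rewrite targets code shaped like the following minimal sketch (the kernel name, sizes, device, and launch grid are illustrative, not taken from this commit): a block is flattened to rank 1 and its only consumer is a reduction, so the rank-1 tt.reshape can be tagged allow_reorder without changing the result.

import torch
import triton
import triton.language as tl


@triton.jit
def block_sum(x_ptr, out_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    x = tl.load(x_ptr + offs)
    # Rank-1 reshape whose only user is a reduce: CombineReshapeReducePatterns
    # can relax it to allow_reorder, since a sum ignores element order.
    flat = tl.reshape(x, [BLOCK_M * BLOCK_N])
    tl.store(out_ptr, tl.sum(flat, axis=0))


# Illustrative launch (assumes a CUDA device is available).
x = torch.randn(32, 8, device="cuda")
out = torch.empty(1, device="cuda", dtype=torch.float32)
block_sum[(1,)](x, out, BLOCK_M=32, BLOCK_N=8)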

python/test/unit/language/test_standard.py

Lines changed: 16 additions & 0 deletions
@@ -98,6 +98,22 @@ def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr):
     torch.testing.assert_close(expect, actual)
 
 
+@pytest.mark.interpreter
+def test_ravel(device):
+
+    @triton.jit
+    def triton_ravel(out_ptr):
+        a = tl.arange(0, 256)
+        a = tl.reshape(a, (32, 8))
+        a = tl.ravel(a)
+        tl.store(out_ptr + tl.arange(0, 256), a)
+
+    out = torch.empty((256, ), device=device, dtype=torch.int32)
+    triton_ravel[(1, )](out)
+
+    assert (out == torch.arange(0, 256, device=device)).all()
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]])
 def test_swizzle2d(size_i, size_j, size_g, device):

python/triton/language/standard.py

Lines changed: 2 additions & 2 deletions
@@ -59,14 +59,14 @@ def softmax(x, ieee_rounding=False):
 
 @core._tensor_member_fn
 @jit
-def ravel(x):
+def ravel(x, can_reorder=False):
     """
     Returns a contiguous flattened view of :code:`x`.
 
     :param x: the input tensor
     :type x: Block
     """
-    return core.reshape(x, [x.numel], can_reorder=True)
+    return core.reshape(x, [x.numel], can_reorder=can_reorder)
 
 
 @jit
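
With this change, tl.ravel preserves element order by default and only reorders when explicitly asked to. A minimal host+kernel sketch of the two behaviours follows; the kernel name, the extra output slot, and device="cuda" are illustrative assumptions, and the can_reorder keyword is assumed to pass through exactly as the new signature suggests.

import torch
import triton
import triton.language as tl


@triton.jit
def ravel_order_demo(out_ptr):
    a = tl.arange(0, 256)
    a = tl.reshape(a, (32, 8))
    # New default (can_reorder=False): row-major element order is preserved.
    flat = tl.ravel(a)
    # When the result only feeds order-insensitive ops (e.g. tl.sum),
    # can_reorder=True lets the backend pick any layout, as before.
    s = tl.sum(tl.ravel(a, can_reorder=True), axis=0)
    tl.store(out_ptr + tl.arange(0, 256), flat)
    tl.store(out_ptr + 256, s)


out = torch.empty(257, device="cuda", dtype=torch.int32)
ravel_order_demo[(1,)](out)
assert (out[:256].cpu() == torch.arange(256, dtype=torch.int32)).all()
assert out[256].item() == sum(range(256))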

test/Conversion/tma_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ tt.func @tma_scatter(%arg0: !tt.ptr<i8>, %arg1: tensor<32xi32, #ttg.slice<{dim =
 // CHECK-SAME: (i1 [[PRED]], ptr addrspace(1) %0, i32 %2, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, ptr addrspace(3) [[PTR]])
   ttng.async_tma_scatter %arg0[%arg1, %arg2] %arg3 : !tt.ptr<i8>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>
 
-  // CHECK: call void @llvm.nvvm.cp.async.commit.group()
+  // CHECK: nvvm.cp.async.bulk.commit.group()
 
   // CHECK-NEXT: ret void
   tt.return

test/Triton/combine.mlir

Lines changed: 13 additions & 0 deletions
@@ -345,3 +345,16 @@ tt.func @test_nested_transpose(%arg0: tensor<2x4x8xf32>) -> (tensor<8x2x4xf32>)
   // CHECK: tt.return %[[res]]
   tt.return %b : tensor<8x2x4xf32>
 }
+
+// CHECK-LABEL: test_reshape_reduce
+tt.func @test_reshape_reduce(%0: tensor<32x4x2xi32>) -> (i32, tensor<16xi32>) {
+  // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<32x4x2xi32> -> tensor<256xi32>
+  %1 = tt.reshape %0 : tensor<32x4x2xi32> -> tensor<256xi32>
+  %2 = "tt.reduce" (%1) ({
+  ^bb0(%arg7: i32, %arg8: i32):
+    %add = arith.addi %arg7, %arg8 : i32
+    tt.reduce.return %add : i32
+  }) {axis = 0 : i32} : (tensor<256xi32>) -> i32
+  %3 = tt.histogram %1 : tensor<256xi32> -> tensor<16xi32>
+  tt.return %2, %3 : i32, tensor<16xi32>
+}

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 71 additions & 3 deletions
@@ -128,7 +128,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
@@ -227,7 +226,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
@@ -288,6 +286,77 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+// CHECK-LABEL: pingpong_medium_cast
+// CHECK-COUNT-2: local_load
+// CHECK-NOT: setprio
+// CHECK-NOT: barrier
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_medium_cast(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %cast2 = tt.bitcast %29 : tensor<64x128xf16, #blocked> -> tensor<64x128xi16, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %cast = tt.bitcast %31 : tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %cast, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+      %33 = arith.addi %arg14, %c1_i32 : i32
+      %34 = arith.cmpi slt, %33, %c1_i32 : i32
+      %35 = arith.select %34, %33, %c0_i32 : i32
+      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %cast2, %37 : tensor<64x128xi16, #blocked> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+
+// -----
+
+
 // CHECK-LABEL: pingpong_reject
 // CHECK-COUNT-2: local_load
 // CHECK-NOT: local_load
@@ -296,7 +365,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 5 additions & 1 deletion
@@ -151,7 +151,11 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
                                         int64_t sliceWidth) {
   SmallVector<Operation *> slices;
   SmallVector<Operation *> subviews;
-  auto memDesc = v.getDefiningOp()->getOperand(0);
+  // TODO: support transformed input to dot
+  auto localLoad = v.getDefiningOp<ttg::LocalLoadOp>();
+  if (!localLoad)
+    return failure();
+  auto memDesc = localLoad.getSrc();
   auto type = cast<ttg::MemDescType>(memDesc.getType());
   SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
   Type elementType = type.getElementType();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 add_triton_library(TritonNVIDIAGPUToLLVM
-  ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
   ConvertLayoutOpToLLVM.cpp
   MemoryOpToLLVM.cpp
   DotOpToLLVM/MMAv2.cpp
