
Commit e14f5b9

Merge OpenAI Triton commit 9410804 (#5132)
This PR changes the Triton base from 625c8cb to 9410804 (Sep 10). Pass rate: 98.8%. Please do not squash and merge this PR.
2 parents f631f67 + 10dfb5f commit e14f5b9

13 files changed: +178 -42 lines changed

13 files changed

+178
-42
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 4 deletions
@@ -1158,6 +1158,14 @@ SharedMemoryObject::getMaskSpanOffsets(triton::gpu::MemDescType srcTy) {
   if (allocShape == shape) {
     return 0;
   }
+  if (auto paddedEncoding = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+          srcTy.getEncoding())) {
+    // Mask is used in fusion of constant part of memory operation address as
+    // immediate operand. Padded layout has additional address computations
+    // between main offset computation and actual memory access, which breaks
+    // constant fusing. Full mask disables this optimization.
+    return ~uint64_t(0);
+  }
   auto totalLl = triton::gpu::toLinearLayout(allocShape, srcTy.getEncoding());
   auto dimNames = standardOutDimNames(ctx, shape.size());
   // Remove the kBlock dimension
@@ -1194,14 +1202,15 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     return b.i32_val(0);
   }

+  LinearLayout ll;
   // We return the offset without the padding. The padding will be added in the
   // lowering
   if (auto paddedSharedEncoding =
           dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
               srcTy.getEncoding())) {
-    auto allocShape64 = srcTy.getAllocShape();
-    SmallVector<unsigned> allocShape(allocShape64.begin(), allocShape64.end());
-    return LLVM::linearize(rewriter, loc, offsets, allocShape);
+    ll = paddedSharedEncoding.getLinearComponent();
+  } else {
+    ll = triton::gpu::toLinearLayout(srcTy);
   }

   auto dimNames = standardOutDimNames(ctx, offsets.size());

@@ -1210,7 +1219,6 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     logicalOffsets.push_back({dim, offset});
   }

-  LinearLayout ll = triton::gpu::toLinearLayout(srcTy);
   ll = ll.sublayout({str_attr("offset")}, dimNames);
   auto offset =
       applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0].second;
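For readers outside the lowering code, the intent of the new early return can be sketched in plain Python. This is illustrative only; the names below are assumptions and do not exist in Triton.

# Illustrative sketch (not Triton API): the mask returned by getMaskSpanOffsets
# marks which offset bits may vary, and the lowering folds the provably-constant
# part of an address into an immediate operand. For padded shared layouts the
# extra padding arithmetic sits between the offset computation and the memory
# access, so the change conservatively reports "every bit may vary" to disable
# that fold.
FULL_MASK = (1 << 64) - 1  # Python analogue of ~uint64_t(0)

def mask_span_offsets(is_padded_layout: bool, computed_mask: int) -> int:
    return FULL_MASK if is_padded_layout else computed_mask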

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 0 deletions
@@ -1815,6 +1815,12 @@ Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) {
     }
   }

+  if (order.size() != shape.size()) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Mismatch of shape and order ranks in padded layout");
+    return {};
+  }
+
   // Create identity mapping based on shape and order
   auto kOffset = StringAttr::get(parser.getContext(), "offset");
   maybeLL = identityStandardND(kOffset, shape, order);
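At the Gluon level, the rank rule this parser check enforces looks roughly like the following. This is a minimal sketch assuming the ttgl.PaddedSharedLayout API used in the new test further down, not code from this PR.

from triton.experimental.gluon import language as ttgl

# `order` must name every dimension of `shape` exactly once, so its rank must
# match the shape's rank. This 2-D layout round-trips through the attribute
# parser; an order such as [0] against a 2-D shape would now be rejected with
# "Mismatch of shape and order ranks in padded layout".
layout = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [64, 64], [1, 0])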

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 10 additions & 1 deletion
@@ -817,7 +817,16 @@ LogicalResult MemDescSubsliceOp::verify() {
   }

   auto ctx = getContext();
-  auto ll = triton::gpu::toLinearLayout(srcTy);
+  LinearLayout ll;
+  if (auto paddedEncoding = dyn_cast<PaddedSharedEncodingAttr>(srcEnc)) {
+    if (paddedEncoding.getRank() < srcTy.getRank()) {
+      return emitError("SubSlice of low rank PaddedSharedEncoding from higher "
+                       "rank tensors is not supported yet");
+    }
+    ll = paddedEncoding.getLinearComponent();
+  } else {
+    ll = triton::gpu::toLinearLayout(srcTy);
+  }
   // NYI: We don't support non-trivial block dimension for now.
   auto kBlock = mlir::StringAttr::get(getContext(), "block");
   if (ll.getInDimSize(kBlock) != 1) {
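The new test in python/test/gluon/test_core.py exercises exactly this path; a trimmed sketch of that usage is shown below. Import paths and the host-side wrapper are assumptions based on the test, and the kernel body is illustrative only.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

def build_subslice_kernel():
    # Layouts are constexpr values; the test builds them on the host and lets
    # the nested kernel capture them, which this sketch mirrors.
    smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [64, 64], [1, 0]))

    @gluon.jit
    def kernel(SLICE_M_OFFSET: ttgl.constexpr, SLICE_M: ttgl.constexpr,
               SLICE_N_OFFSET: ttgl.constexpr, SLICE_N: ttgl.constexpr):
        smem = ttgl.allocate_shared_memory(ttgl.int32, [64, 64], smem_layout)
        # Each subslice keeps the full-rank padded encoding, which the relaxed
        # verifier accepts; subslicing a lower-rank padded encoding out of a
        # higher-rank tensor is still rejected with the new error message.
        smem_slice0 = smem.slice(SLICE_M_OFFSET, SLICE_M, dim=0)
        smem_slice1 = smem_slice0.slice(SLICE_N_OFFSET, SLICE_N, dim=1)

    return kernel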

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 4 additions & 14 deletions
@@ -177,29 +177,19 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                          << "the shape size when pipelining.";
     }

-    // Subslices are not yet implemented
-    auto subsliceAllocSize =
-        allocShape.drop_front(allocShape.size() - shape.size());
-    for (auto [allocDim, shapeDim] : llvm::zip(shape, subsliceAllocSize)) {
-      if (allocDim != shapeDim) {
-        return emitError() << "Subslices with padded encodings are not yet "
-                           << "implemented.";
-      }
-    }
-
     // Ensure linear component's outDims match the alloc size ignoring
     // pipelining dimension
     auto outDims = standardOutDimNames(ctx, rank);
     const auto &ll = enc.getLinearComponent();
-    auto expectedShape = shape;
-    if (shape.size() == allocShape.size() && shape.size() == rank + 1)
+    auto expectedShape = allocShape;
+    if (rank == allocShape.size() - 1)
      expectedShape = expectedShape.drop_front(1);

     for (auto d = 0; d < rank; d++) {
       if (ll.getOutDimSize(outDims[d]) != expectedShape[d]) {
         return emitError() << "Mismatch in expected shape for dimension " << d
-                           << ". Expected: " << ll.getOutDimSize(outDims[d])
-                           << ", got: " << expectedShape[d];
+                           << ". Expected: " << expectedShape[d]
+                           << ", got: " << ll.getOutDimSize(outDims[d]);
       }
     }
   }

python/test/gluon/test_core.py

Lines changed: 55 additions & 0 deletions
@@ -897,3 +897,58 @@ def kernel(x, y):

     compiled_kernel = kernel.warmup(input, output, grid=(1, ))
     assert compiled_kernel.asm["ttgir"].count("tt.func private") == 0
+
+
+@pytest.mark.parametrize("interval_pairs", [[[32, 4]], [[16, 4]], [[16, 4], [64, 8]]])
+@pytest.mark.parametrize(
+    "shared_layout",
+    [{"order": [0, 1]}, {"order": [1, 0]},
+     {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
+@pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
+                                                                              (48, 32, 16, 32)])
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+    m = 64
+    n = 64
+    num_warps = 1
+    num_warps_cst = ttgl.constexpr(num_warps)
+    warp_size_cst = ttgl.constexpr(THREADS_PER_WARP)
+
+    shape = [m, n]
+    if "order" in shared_layout:
+        order = shared_layout["order"]
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout.with_identity_for(interval_pairs, shape, order))
+    elif "offsets" in shared_layout:
+        offsets = shared_layout["offsets"]
+        blocks = []
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout(interval_pairs, offsets, blocks, shape))
+
+    @gluon.jit
+    def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET: ttgl.constexpr,
+               SLICE_N_OFFSET: ttgl.constexpr, SLICE_M: ttgl.constexpr, SLICE_N: ttgl.constexpr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [warp_size_cst, 1], [1, num_warps_cst], [1, 0])
+        offs_m_load = ttgl.arange(0, M, ttgl.SliceLayout(1, blocked))
+        offs_n_load = ttgl.arange(0, N, ttgl.SliceLayout(0, blocked))
+        in_offs = offs_m_load[:, None] * N + offs_n_load[None, :]
+
+        in_data = ttgl.load(in_ptr + in_offs)
+
+        smem = ttgl.allocate_shared_memory(ttgl.int32, [M, N], smem_layout)
+        smem_slice0 = smem.slice(SLICE_M_OFFSET, SLICE_M, dim=0)
+        smem_slice1 = smem_slice0.slice(SLICE_N_OFFSET, SLICE_N, dim=1)
+
+        smem.store(in_data)
+
+        out_data = smem_slice1.load(blocked)
+
+        offs_m_store = ttgl.arange(0, SLICE_M, ttgl.SliceLayout(1, blocked))
+        offs_n_store = ttgl.arange(0, SLICE_N, ttgl.SliceLayout(0, blocked))
+        out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
+        ttgl.store(out_ptr + out_offs, out_data)
+
+    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]
+
+    kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)
+
+    assert (output == ref_output).all()

python/test/gluon/test_frontend.py

Lines changed: 26 additions & 0 deletions
@@ -2375,3 +2375,29 @@ def test_layout_zeros():
     # CHECK: #blocked = #ttg.blocked
     # CHECK: arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
     ttgl.zeros([128], ttgl.float32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))
+
+
+@gluon.jit
+def print_num_warps():
+    num_warps: ttgl.constexpr = ttgl.num_warps()
+    print("num_warps", num_warps)
+
+
+@filecheck_test
+@gluon.jit
+def test_get_num_warps():
+    # CHECK-LABEL: test_get_num_warps
+    # CHECK: tt.func private @{{.*}}print_num_warps
+    # CHECK-NEXT arith.constant 4 : i32
+
+    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW1
+    # CHECK-NEXT arith.constant 1 : i32
+
+    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW2
+    # CHECK-NEXT arith.constant 2 : i32
+
+    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW8
+    # CHECK-NEXT arith.constant 8 : i32
+    print_num_warps()
+    ttgl.warp_specialize((), print_num_warps, (), [print_num_warps, print_num_warps, print_num_warps], [1, 2, 8],
+                         [24, 24, 24])

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
     expand_dims,
     full,
     gather,
+    num_warps,
    histogram,
    inline_asm_elementwise,
    join,

python/triton/experimental/gluon/language/_core.py

Lines changed: 8 additions & 0 deletions
@@ -501,6 +501,14 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
                                   worker_num_regs, _generator)


+@builtin
+def num_warps(_semantic=None, _generator=None):
+    """
+    Returns the number of warps that execute the current context, including in warp-specialized regions.
+    """
+    return _semantic.num_warps(_generator)
+
+
 @builtin
 def thread_barrier(_semantic=None):
     """

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 6 additions & 0 deletions
@@ -427,3 +427,9 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
         if default_results is None:
             return
         return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
+
+    def num_warps(self, generator):
+        if generator.caller_context is not None:
+            assert isinstance(generator.caller_context, GluonCallerContext)
+            return ttgl.constexpr(generator.caller_context.num_warps)
+        return ttgl.constexpr(self.builder.options.num_warps)

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 24 additions & 0 deletions
@@ -510,6 +510,30 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

 // -----

+// CHECK-LABEL: padded_shared_layout_subslice_load_store
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [32, 32]}>
+#smem = #ttg.shared_memory
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @padded_shared_layout_subslice_load_store(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr<3>
+    // CHECK-NOT: llvm.store
+    %0 = ttg.local_alloc %arg0 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
+    %1 = ttg.memdesc_subslice %0 [16, 0] : !ttg.memdesc<32x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32>
+    // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>
+    // CHECK-NOT: llvm.load
+    %2 = ttg.local_load %1: !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK-COUNT-2: llvm.store {{.*}} : vector<4xf16>, !llvm.ptr<3>
+    // CHECK-NOT: llvm.store
+    ttg.local_store %2, %1 : tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32>
+    tt.return
+  }
+}
+
+// -----
+
 // GFX950-LABEL: reduce_32x32
 // GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
