Commit e4d0843

Merge commit '35f1827581071a5ac3a385f8776ab1a3a784811a'
2 parents: a5393bf + 35f1827

6 files changed: 135 additions, 20 deletions

include/triton/Dialect/Triton/IR/TritonOps.td
Lines changed: 8 additions & 1 deletion

@@ -778,7 +778,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
 def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
                                                           SameOperandsAndResultEncoding,
                                                           SameVariadicOperandSize,
-                                                          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+                                                          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                                                          ConditionallySpeculatable]> {

     let description = [{
         call an external function $symbol implemented in $libpath/$libname with $args
@@ -790,6 +791,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
     let results = (outs TT_Type:$result);

     let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
+
+    let extraClassDeclaration = [{
+      // Interface method for ConditionallySpeculatable.
+      Speculation::Speculatability getSpeculatability();
+    }];
+
 }

 //

lib/Analysis/Utility.cpp
Lines changed: 19 additions & 12 deletions

@@ -71,18 +71,25 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
   }

   unsigned threadOffset = 1;
-  if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
-    auto parentLayout = sliceLayout.getParent();
-    auto threadsPerWarp = getThreadsPerWarp(parentLayout);
-    threadOffset = threadsPerWarp[sliceLayout.getDim()];
-  } else {
-    auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getThreadOrder(srcLayout);
-    for (unsigned i = 0; i < order.size(); i++) {
-      if (order[i] == axis)
-        break;
-      threadOffset *= threadsPerWarp[order[i]];
-    }
+  SmallVector<int> dimsRemoved;
+  while (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
+    dimsRemoved.push_back(sliceLayout.getDim());
+    srcLayout = sliceLayout.getParent();
+  }
+  // In case of slice layout we want to know the axis dimension relative to the
+  // most inner parent layout. `adjustedAxis` is the matching axis dim in the
+  // parent layout.
+  int adjustedAxis = axis;
+  for (auto dim : dimsRemoved) {
+    if (dim <= adjustedAxis)
+      adjustedAxis++;
+  }
+  auto threadsPerWarp = getThreadsPerWarp(srcLayout);
+  auto order = getThreadOrder(srcLayout);
+  for (unsigned i = 0; i < order.size(); i++) {
+    if (order[i] == adjustedAxis)
+      break;
+    threadOffset *= threadsPerWarp[order[i]];
   }
   return threadOffset;
 }
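
The remapping above is easiest to see with concrete numbers. Below is a minimal, self-contained C++ sketch (not Triton code) that mirrors the same logic: dimsRemoved stands in for the dims peeled off nested SliceEncodingAttr layouts, and the threadsPerWarp/order values are made-up, not results of the real TritonGPU layout helpers.

// Minimal sketch, assuming a made-up layout; mirrors the adjusted-axis logic
// of the patched getThreadOffsetOnReductionAxis, not the real function.
#include <cassert>
#include <vector>

// Map an axis of the sliced (child) layout to the matching dimension of the
// innermost parent layout: each removed dim at or below the axis shifts it up.
int adjustAxisToParent(int axis, const std::vector<int> &dimsRemoved) {
  int adjustedAxis = axis;
  for (int dim : dimsRemoved)
    if (dim <= adjustedAxis)
      ++adjustedAxis;
  return adjustedAxis;
}

// Multiply threadsPerWarp over every dimension that precedes the (adjusted)
// reduction axis in the thread order.
unsigned threadOffsetOnAxis(int axis, const std::vector<int> &dimsRemoved,
                            const std::vector<unsigned> &threadsPerWarp,
                            const std::vector<unsigned> &order) {
  int adjustedAxis = adjustAxisToParent(axis, dimsRemoved);
  unsigned threadOffset = 1;
  for (unsigned d : order) {
    if (static_cast<int>(d) == adjustedAxis)
      break;
    threadOffset *= threadsPerWarp[d];
  }
  return threadOffset;
}

int main() {
  // Hypothetical 3-D parent layout: threadsPerWarp = [4, 8, 1], thread order
  // [2, 1, 0]; dim 1 was sliced away and the reduction runs over axis 0 of the
  // remaining 2-D view, which is still dim 0 of the parent.
  unsigned off = threadOffsetOnAxis(/*axis=*/0, /*dimsRemoved=*/{1},
                                    /*threadsPerWarp=*/{4, 8, 1},
                                    /*order=*/{2, 1, 0});
  assert(off == 8u); // dims 2 and 1 come before dim 0 in the order: 1 * 8
  return 0;
}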

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Lines changed: 51 additions & 5 deletions

@@ -218,6 +218,46 @@ struct ReduceOpConversion
     rewriter.replaceOp(op, results);
   }

+  // For slice layout some ids are duplicated on multiple lanes, so we need to
+  // handle the delinearization of laneId in a special way. We need to
+  // generalize this part of the logic to work on any kind of linear layout
+  // uniformely.
+  SmallVector<Value>
+  getMultiDimLaneId(ReduceOpHelper &helper, Value &laneId, Location &loc,
+                    ConversionPatternRewriter &rewriter) const {
+    auto srcLayout = helper.getSrcLayout();
+    auto srcShape = helper.getSrcShape();
+    auto order = triton::gpu::getThreadOrder(srcLayout);
+    SmallVector<Value> multiDimLaneId;
+
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
+      auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
+      auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp(parentLayout);
+      auto parentOrder = triton::gpu::getThreadOrder(parentLayout);
+      multiDimLaneId = delinearize(rewriter, loc, laneId, parentThreadsPerWarps,
+                                   parentOrder);
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimLaneId.erase(multiDimLaneId.begin() + dim);
+      }
+    } else {
+      SmallVector<unsigned> threadsPerWarps =
+          triton::gpu::getThreadsPerWarp(srcLayout);
+      threadsPerWarps[helper.getAxis()] =
+          triton::gpu::getThreadsPerWarpWithUniqueData(
+              srcLayout, srcShape)[helper.getAxis()];
+      multiDimLaneId =
+          delinearize(rewriter, loc, laneId, threadsPerWarps, order);
+    }
+    return multiDimLaneId;
+  }
+
   SmallVector<Value>
   getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc,
                     ConversionPatternRewriter &rewriter) const {
@@ -231,11 +271,20 @@ struct ReduceOpConversion
     // a way to properly delinearize warpId in the slice case
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
       auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
       auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
       auto parentOrder = triton::gpu::getWarpOrder(parentLayout);
       multiDimWarpId =
           delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
-      multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimWarpId.erase(multiDimWarpId.begin() + dim);
+      }
     } else {
       SmallVector<unsigned> warpsPerCTA =
           triton::gpu::getWarpsPerCTA(srcLayout);
@@ -263,11 +312,8 @@ struct ReduceOpConversion
     unsigned axis = op.getAxis();
     auto smemShape = helper.getScratchRepShape();

-    auto threadsPerWarp =
-        triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
-        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+        getMultiDimLaneId(helper, laneId, loc, rewriter);
     Value laneIdAxis = multiDimLaneId[axis];
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);
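
As a companion to the new getMultiDimLaneId helper, here is a minimal C++ sketch of the delinearize-then-erase step, using plain integers where the patch uses MLIR Values; delinearizeLane and the layout numbers below are hypothetical stand-ins for delinearize() and getThreadsPerWarp()/getThreadOrder().

// Minimal sketch, assuming a made-up parent layout; not Triton code.
#include <algorithm>
#include <cstdio>
#include <vector>

// Split a linear lane id into one coordinate per dimension, consuming the
// fastest-varying dimension (order[0]) first.
std::vector<int> delinearizeLane(int laneId, const std::vector<int> &shape,
                                 const std::vector<int> &order) {
  std::vector<int> multiDim(shape.size(), 0);
  for (int d : order) {
    multiDim[d] = laneId % shape[d];
    laneId /= shape[d];
  }
  return multiDim;
}

int main() {
  // Hypothetical parent layout: threadsPerWarp = [1, 8, 4], order = [2, 1, 0],
  // with dim 0 sliced away.
  std::vector<int> threadsPerWarp = {1, 8, 4};
  std::vector<int> order = {2, 1, 0};
  std::vector<int> slicedDims = {0};

  int laneId = 13;
  std::vector<int> multiDimLaneId =
      delinearizeLane(laneId, threadsPerWarp, order);

  // Drop the coordinates of the sliced dims, highest index first so earlier
  // erases do not shift the later ones (the patch iterates the collected dims
  // with llvm::reverse before erasing).
  std::sort(slicedDims.rbegin(), slicedDims.rend());
  for (int dim : slicedDims)
    multiDimLaneId.erase(multiDimLaneId.begin() + dim);

  // Lane 13 maps to coordinates [3, 1] in the remaining two dimensions.
  for (int c : multiDimLaneId)
    std::printf("%d ", c);
  std::printf("\n");
  return 0;
}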

lib/Dialect/Triton/IR/Ops.cpp
Lines changed: 6 additions & 0 deletions

@@ -1039,6 +1039,12 @@ void ExternElementwiseOp::getEffects(
                        SideEffects::DefaultResource::get());
 }

+Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
+  if (getPure())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
+
 // -- ExperimentalTensormapCreateOp --
 LogicalResult ExperimentalTensormapCreateOp::verify() {
   auto rank = getBoxDim().size();
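
For context on what this hook changes downstream, the following is an illustrative analogue only (plain C++, hypothetical names, not the MLIR or Triton API): a transform asks the op whether it is safe to execute speculatively, and only a pure call answers yes, which is what lets the lowering in the MLIR test below avoid emitting a guarding llvm.cond_br.

// Illustrative analogue, assuming hypothetical ExternCall/canEmitWithoutGuard.
#include <cassert>

enum class Speculatability { Speculatable, NotSpeculatable };

struct ExternCall {
  bool pure; // stands in for the op's `pure` attribute

  // Same decision as the patch: only pure calls may run speculatively.
  Speculatability getSpeculatability() const {
    return pure ? Speculatability::Speculatable
                : Speculatability::NotSpeculatable;
  }
};

// A lowering can skip the guarding branch around a region only when the op it
// wraps is safe to execute unconditionally (reduced to a single call here).
bool canEmitWithoutGuard(const ExternCall &call) {
  return call.getSpeculatability() == Speculatability::Speculatable;
}

int main() {
  assert(canEmitWithoutGuard(ExternCall{/*pure=*/true}));
  assert(!canEmitWithoutGuard(ExternCall{/*pure=*/false}));
  return 0;
}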

python/test/unit/language/test_core.py
Lines changed: 33 additions & 2 deletions

@@ -31,6 +31,7 @@
     is_hip,
     is_hip_cdna,
     is_hip_mi200,
+    is_hip_mi300,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -3414,8 +3415,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
     if is_hip():
         if not is_hip_cdna():
             pytest.skip("scaled_dot only implemented for HIP CDNA")
-        if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
-            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if "e4m3" in (type_a, type_b) and not is_hip_mi300():
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) only implemented for MI300")
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
     if is_xpu():
@@ -6072,3 +6073,33 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr):
     Z = torch.zeros_like(X)
     sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK)
     torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32))
+
+
+# stress test slice layout usages in reductions.
+@pytest.mark.parametrize("in_shape, perm, red_dims", [
+    ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]),
+    ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]),
+])
+def test_chained_reductions(in_shape, perm, red_dims, device):
+
+    @triton.jit
+    def kernel(In, Out,  #
+               dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr,
+               perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr,
+               perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr):
+        idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4)
+        idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4)
+        vals = tl.load(In + idx)
+        vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4])
+        r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2)
+        st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape)
+        tl.store(Out + st_idx, r)
+
+    input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32)
+    temp = torch.permute(input, perm).contiguous()
+    ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2])
+    result = torch.empty_like(ref)
+    kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4],
+                  perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2])
+
+    assert torch.all(ref == result)

test/Conversion/tritongpu_to_llvm.mlir
Lines changed: 18 additions & 0 deletions

@@ -1855,3 +1855,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 #loc3 = loc("inner_call":29:28)
 #loc4 = loc(callsite(#loc3 at #loc1))
 #loc5 = loc(callsite(#loc4 at #loc2))
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} {
+    // CHECK: log1pf_scan
+    // non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable.
+    // CHECK-NOT: llvm.cond_br
+    %40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({
+    ^bb0(%arg5: f32, %arg6: f32):
+      %43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32
+      %44 = arith.addf %43, %43 : f32
+      tt.scan.return %44 : f32
+    }) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked>
+    tt.return
+  }
+}
