
Commit bfc51a4

Merge commit '68aa962e67baa191cec5aac173255abdba80db1a'
2 parents (a405aa2 + 68aa962), commit bfc51a4

File tree: 7 files changed (+79, -101 lines)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,9 @@ python/triton/backends/
 !python/triton/backends/compiler.py
 !python/triton/backends/driver.py

+# Language extras
+python/triton/language/extra
+
 # Proton
 python/triton/profiler

lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 0 deletions
@@ -278,6 +278,11 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 private:
   int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                         int dim) override {
+    // Contiguity assumes an increasing sequence. So for SubIOp contiguous
+    // RHS doesn't produce a contiguous result.
+    if (isa<arith::SubIOp>(op))
+      return gcd(lhs.getContiguity(dim), rhs.getConstancy(dim));
+
     return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)),
                     gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)));
   }
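The rule the new branch encodes can be checked with a small arithmetic example. The following is an informal Python sketch of the gcd bookkeeping, not the analysis code itself: for a constant LHS and a contiguous RHS, lhs - rhs is a decreasing sequence, yet the symmetric max/gcd rule used for addition would still report it as contiguous.

from math import gcd

# Toy model of the AxisInfo contiguity rules (illustrative only; the real
# logic lives in AddSubOpAxisInfoVisitor above).
def add_rule(lhs_contig, lhs_const, rhs_contig, rhs_const):
    return max(gcd(lhs_const, rhs_contig), gcd(lhs_contig, rhs_const))

def sub_rule(lhs_contig, lhs_const, rhs_contig, rhs_const):
    # For a - b, only a constant RHS preserves an increasing run of the LHS.
    return gcd(lhs_contig, rhs_const)

# lhs = [512, 512, ..., 512] (contiguity 1, constancy 128)
# rhs = [0, 1, ..., 127]     (contiguity 128, constancy 1)
# lhs - rhs = [512, 511, ...] is decreasing, i.e. not contiguous.
print(add_rule(1, 128, 128, 1))  # 128: the old symmetric rule over-reports
print(sub_rule(1, 128, 128, 1))  # 1: the new SubIOp branch does not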

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 45 additions & 84 deletions
@@ -219,8 +219,9 @@ static void createTMAAsyncCopy(
 // encodings, raise assertion, since incompatible shared encoding has been
 // handled in splitLoadsForIncompatible.
 static std::optional<ttg::SharedEncodingAttr>
-getSharedEncIfAllUsersAreDotEnc(Value val) {
+getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
   ttg::SharedEncodingAttr attr;
+  incompatible = false;
   for (Operation *user : val.getUsers()) {
     ttg::SharedEncodingAttr tempAttr;
     if (user->getNumResults() != 1)
@@ -230,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       // First time we find a shared encoding in the chain, save it and try to
       // use it if it is compatible with the other users.
       tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
-      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value())
+      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible)
+               .has_value())
         return std::nullopt;
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
@@ -248,8 +250,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
           bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
-    if (attr != nullptr)
-      assert(attr == tempAttr && "incompatible shared encoding");
+    if (attr != nullptr && attr != tempAttr) {
+      incompatible = true;
+      return std::nullopt;
+    }
     attr = tempAttr;
   }
   return attr;
@@ -439,8 +443,44 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
         loadInfo.sharedEncoding =
             getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr);
       } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
+        bool incompatible = false;
         loadInfo.sharedEncoding =
-            getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
+            getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
+                .value_or(nullptr);
+        // If we can't agree on a shared encoding skip pipelinig the load.
+        if (incompatible)
+          continue;
+
+        // HACK: Triton LLVM codegen has a bug where local_loads from #shared to
+        // #mma layout can lead to invalid code if the loaded shape is smaller
+        // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with
+        // tile {16,8} is bad because 1 < 8). To work around this, don't
+        // pipeline such loads.
+        //
+        // The codegen bug is caught by an assertion, so if you think you've
+        // fixed it, feel free to delete this code and see if the assert still
+        // fails. :)
+        if (!loadInfo.sharedEncoding) {
+          if (auto dotEnc = dyn_cast<ttg::NvidiaMmaEncodingAttr>(
+                  dot.getResult().getType().getEncoding())) {
+            auto loadTy = cast<RankedTensorType>(op->getResultTypes()[0]);
+            auto mmaInstrShape = dotEnc.getInstrShape();
+            if (loadTy.getRank() < mmaInstrShape.size())
+              continue;
+            bool ok = true;
+            for (int i = 0; i < mmaInstrShape.size(); i++) {
+              if (loadTy.getShape()[loadTy.getRank() - mmaInstrShape.size() +
+                                    i] < mmaInstrShape[i]) {
+                ok = false;
+                break;
+              }
+            }
+            // If this load might trigger the bug, don't do the fallback logic
+            // below, which might allow the load to be pipelined.
+            if (!ok)
+              continue;
+          }
+        }
       }
     } else if (auto loadOp = dyn_cast<tt::LoadOp>(use)) {
       // The use of this loadOp is another loadOp. If the use is not in the
@@ -476,83 +516,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
   return loadToInfo;
 }

-// Split users to groups, each group has the same shared encoding.
-// If not all users are Dot encoding, return empty vector.
-static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
-handleIncompatibleSharedEncoding(Operation *loadOp) {
-  DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
-  // Go through transitive uses of the loadOp in the same block.
-  for (Operation *user : loadOp->getUsers()) {
-    if (user->getBlock() != loadOp->getBlock())
-      continue;
-    if (user->getNumResults() != 1)
-      return loadGroups;
-
-    ttg::SharedEncodingAttr tempAttr;
-    if (auto memDesc =
-            dyn_cast<triton::MemDescType>(user->getResult(0).getType())) {
-      tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
-      loadGroups[tempAttr].push_back(user);
-    } else {
-      if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
-        return loadGroups;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
-          cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
-      if (!dotOpEnc)
-        return loadGroups;
-      auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult(0).getType());
-      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
-      auto order = ttg::getOrder(srcTy.getEncoding());
-      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
-      tempAttr = ttg::SharedEncodingAttr::get(
-          loadOp->getContext(), dotOpEnc, srcTy.getShape(),
-          ttg::getOrder(srcTy.getEncoding()),
-          ttg::getCTALayout(srcTy.getEncoding()),
-          srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
-      loadGroups[tempAttr].push_back(user);
-    }
-  }
-  return loadGroups;
-}
-
-// Clone loads so each group of uses with same shared encoding will have a
-// corresponding Load.
-static void splitLoadsForIncompatible(
-    OpBuilder &builder, Operation *loadOp,
-    DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
-  // The first group will use the original load, create new loads for other
-  // groups.
-  unsigned idx = 0;
-  builder.setInsertionPointAfter(loadOp);
-  for (auto pair : lGroups) {
-    SmallVector<Operation *> &group = pair.second;
-    if (idx++ == 0)
-      continue;
-    Operation *newLoad = builder.clone(*loadOp);
-    for (auto *user : group) {
-      user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0));
-    }
-  }
-}
-
-static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) {
-  // Get the list of all loads.
-  SmallVector<Operation *> loads;
-  for (Operation &op : forOp.getBody()->without_terminator()) {
-    if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
-      loads.push_back(&op);
-    }
-  }
-  OpBuilder builder(forOp);
-  for (auto *loadOp : loads) {
-    auto lGroups = handleIncompatibleSharedEncoding(loadOp);
-    LDBG("groups with different encoding: " << lGroups.size() << " "
-                                            << *loadOp);
-    if (lGroups.size() > 1)
-      splitLoadsForIncompatible(builder, loadOp, lGroups);
-  }
-}
-
 static llvm::MapVector<Operation *, LoadInfo>
 scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
               DenseSet<Operation *> &rootUsers, int numStages) {
@@ -1106,8 +1069,6 @@ static void invalidateBarriers(OpBuilder &builder,

 bool mlir::triton::preProcessLoopAndGetSchedule(
     scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
-  splitLoadsWithIncompatibleEncoding(forOp);
-
   // Schedule the loads and root ops (dot ops) in the loop. This will give us
   // a scaffold for the final schedule.
   DenseSet<Operation *> rootUsers;
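Taken together, these changes replace the old strategy of cloning loads per shared encoding (the deleted splitLoads* helpers) with simply skipping pipelining when the users disagree on a shared encoding. For intuition about the new MMA-tile guard, here is a rough Python sketch with a hypothetical helper name, not the C++ above: the trailing dimensions of the loaded tensor must each be at least as large as the MMA instruction shape, otherwise the load is left unpipelined.

def mma_tile_fits(load_shape, instr_shape):
    # Mirrors the guard added in assignMemoryLayouts: the last len(instr_shape)
    # dims of the load must cover the MMA tile. A 128x1 load against an MMAv2
    # instr shape of (16, 8) fails because 1 < 8.
    if len(load_shape) < len(instr_shape):
        return False  # the C++ code also skips pipelining in this case
    trailing = load_shape[len(load_shape) - len(instr_shape):]
    return all(t >= m for t, m in zip(trailing, instr_shape))

assert mma_tile_fits((128, 64), (16, 8))
assert not mma_tile_fits((128, 1), (16, 8))  # the dont_pipeline_128x1 case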

python/setup.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
 from io import BytesIO
 from distutils.command.clean import clean
 from pathlib import Path
-from typing import NamedTuple
+from typing import List, NamedTuple

 from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
@@ -32,8 +32,8 @@
 @dataclass
 class Backend:
     name: str
-    package_data: list[str]
-    language_package_data: list[str]
+    package_data: List[str]
+    language_package_data: List[str]
     src_dir: str
     backend_dir: str
     language_dir: str
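This presumably keeps setup.py importable on older interpreters (an assumption, the commit does not state the motivation): subscripting the built-in list (PEP 585) works at runtime only on Python 3.9+, while typing.List also works on 3.8. A minimal illustration:

from dataclasses import dataclass
from typing import List

@dataclass
class Backend:
    name: str
    package_data: List[str]    # fine on Python 3.8 and later
    # package_data: list[str]  # TypeError on 3.8 when the annotation is
    #                          # evaluated: 'type' object is not subscriptable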

python/test/regression/test_functional_regressions.py

Lines changed: 15 additions & 0 deletions
@@ -226,6 +226,21 @@ def grid(META):
     torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)


+def test_reverse_range(device):
+
+    @triton.jit
+    def kernel(in_ptr, out_ptr):
+        x0 = tl.arange(0, 512)
+        tmp0 = tl.load(in_ptr + (512 - x0))
+        tl.store(out_ptr + x0, tmp0)
+
+    data = torch.randn((516, ), dtype=torch.float32, device=device)
+    res = torch.empty((512, ), dtype=torch.float32, device=device)
+    kernel[(1, )](data, res)
+    ref = torch.flip(data[1:513], [0])
+    assert (res == ref).all()
+
+
 @triton.jit
 def _triton_cummax_helper_fn(arg0_0, arg0_1, arg1_0, arg1_1):
     tmp0 = arg0_0 > arg1_0
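The new test appears to exercise the AxisInfo subtraction fix above: the offsets 512 - x0 decrease, so the load must not be treated as contiguous. A NumPy sketch of what the kernel computes (illustrative only, outside the Triton execution model):

import numpy as np

data = np.random.randn(516).astype(np.float32)
x0 = np.arange(512)
res = data[512 - x0]        # decreasing indices 512, 511, ..., 1
ref = data[1:513][::-1]     # same as torch.flip(data[1:513], [0])
assert np.array_equal(res, ref)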

test/Analysis/test-alignment.mlir

Lines changed: 4 additions & 2 deletions
@@ -97,10 +97,12 @@ tt.func @sub() {
   %1 = arith.constant dense<1> : tensor<128xi32>
   // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
   %2 = arith.subi %0, %1 : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
+  %3 = arith.subi %1, %0 : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129
-  %3 = arith.constant dense<129> : tensor<128xi32>
+  %4 = arith.constant dense<129> : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
-  %4 = arith.subi %3, %1 : tensor<128xi32>
+  %5 = arith.subi %4, %1 : tensor<128xi32>
   tt.return
 }

test/TritonGPU/loop-pipeline.mlir

Lines changed: 4 additions & 12 deletions
@@ -844,16 +844,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
     %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
     %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-    // check that the load with incompatiable shared encoding gets cloned and feeds into uses with same encoding
-    // AMD-NOT: alloc
-    // AMD: scf.for
-    // CHECK: local_alloc
-    // CHECK: local_alloc
-    // CHECK: scf.for
-    // CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1
-    // CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0
-    // CHECK: tt.dot
-    // CHECK: tt.trans %arg
+    // check that the load didn't get pipelined.
+    // COMMON-NOT: alloc
+    // COMMON: scf.for
     %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
       %18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
       %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
@@ -1460,8 +1453,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // -----

 // COMMON-LABEL: @dont_pipeline_128x1
-// AMD-NOT: local_load{{.*}}128x1
-// CHECK: local_load{{.*}}128x1
+// COMMON-NOT: local_load{{.*}}128x1
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
