Commit 699ff73

Merge commit 'cfe3dd079865d9f23cc6d7854de707636892cacf'
2 parents: 688adf6 + cfe3dd0

95 files changed: +1099 additions, -771 deletions

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 #include "triton/Conversion/MLIRTypes.h"

 namespace mlir::triton {
+enum class ProgramIDDim : uint32_t;

 class TargetInfoBase {
 public:
@@ -48,7 +49,7 @@ class TargetInfoBase {
                            Value i) const = 0;

   virtual Value programId(RewriterBase &rewriter, Location loc,
-                          ModuleOp moduleOp, int axis) const = 0;
+                          ModuleOp moduleOp, ProgramIDDim axis) const = 0;

   virtual bool warpReduce(RewriterBase &rewriter, Location loc,
                           SmallVector<Value> &acc, triton::ReduceOp op,

lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 4 additions & 5 deletions
@@ -26,11 +26,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();

-    auto getPid = [&](int axis) {
-      return targetInfo.programId(rewriter, loc,
-                                  op->getParentOfType<ModuleOp>(), axis);
-    };
-    std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
+    std::array<Value, 3> pid;
+    auto module = op->getParentOfType<ModuleOp>();
+    for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z})
+      pid[(int)axis] = targetInfo.programId(rewriter, loc, module, axis);

     // Simple printf of a string without any tensors.
     if (op.getNumOperands() == 0) {

lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ struct GetProgramIdOpConversion
   LogicalResult
   matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value programId = targetInfo.programId(rewriter, op->getLoc(),
-                                           op->getParentOfType<ModuleOp>(),
-                                           op.getAxisAsInt());
+    Value programId = targetInfo.programId(
+        rewriter, op->getLoc(), op->getParentOfType<ModuleOp>(), op.getAxis());
     rewriter.replaceOp(op, programId);
     return success();
   }
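
The three hunks above replace the raw int axis of TargetInfoBase::programId (and of the tt.get_program_id lowering) with the ProgramIDDim enum forward-declared in TargetInfoBase.h. Below is a small standalone sketch of the same typed-axis pattern; only the enumerator names X, Y, Z and the array-indexed-by-enum-value shape come from the PrintOpToLLVM hunk, while program_id and the register-name strings are hypothetical stand-ins, not part of the Triton API.

from enum import IntEnum

# Stand-in for the C++ `enum class ProgramIDDim : uint32_t` declared in
# TargetInfoBase.h; X/Y/Z mirror the values iterated in PrintOpToLLVM.cpp.
class ProgramIDDim(IntEnum):
    X = 0
    Y = 1
    Z = 2

# Hypothetical lowering helper: with a typed axis, call sites name the
# dimension explicitly instead of passing a bare 0/1/2.
def program_id(axis: ProgramIDDim) -> str:
    return f"%ctaid.{axis.name.lower()}"

# Same shape as the new PrintOpToLLVM code: fill a 3-slot array indexed by
# the enum's integral value.
pid = [None] * 3
for axis in (ProgramIDDim.X, ProgramIDDim.Y, ProgramIDDim.Z):
    pid[int(axis)] = program_id(axis)
print(pid)  # ['%ctaid.x', '%ctaid.y', '%ctaid.z']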

lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp

Lines changed: 97 additions & 27 deletions
@@ -1,3 +1,4 @@
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
@@ -9,6 +10,8 @@
 #include "llvm/ADT/PriorityWorklist.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/xxhash.h"

 namespace ttg = mlir::triton::gpu;

@@ -28,6 +31,65 @@ bool isAutoEncodingTensorType(Type ty) {
   return tensorTy && isa<gluon::AutoEncodingAttr>(tensorTy.getEncoding());
 }

+struct LayoutInfo {
+  Attribute encoding;
+  // Some operations can infer one of many encodings,
+  // we model this by setting the mayVary flag on encodings
+  // derived from these ops.
+  // If "may vary" is set then we allow conflicts, and when
+  // resolving conflicts we prefer encodings that are not allowed to vary.
+  bool mayVary = false;
+
+  operator bool() { return bool(encoding); }
+};
+
+uint64_t hashWithMemo(Attribute attr,
+                      llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  auto it = hashMemo.find(attr);
+  if (it != hashMemo.end()) {
+    return it->second;
+  }
+
+  // llvm::hash_value is not stable, so instead we hash the string repr of the
+  // attribute
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  attr.print(os);
+  auto hash = llvm::xxh3_64bits(str);
+  hashMemo.try_emplace(attr, hash);
+  return hash;
+}
+
+bool compare(Attribute a, Attribute b,
+             llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  if (a == b)
+    return false;
+
+  return hashWithMemo(a, hashMemo) > hashWithMemo(b, hashMemo);
+}
+
+LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op,
+                       llvm::MapVector<Attribute, uint64_t> &hashMemo) {
+  // Sort inputs so this operation is commutative
+  if (compare(lhs.encoding, rhs.encoding, hashMemo)) {
+    std::swap(lhs, rhs);
+  }
+  if (lhs.mayVary)
+    return rhs;
+  if (rhs.mayVary)
+    return lhs;
+  if (lhs.encoding == rhs.encoding)
+    return lhs;
+  op->emitOpError("found conflicting encodings for value:\n  ")
+      << lhs.encoding << "\nand\n  " << rhs.encoding;
+  return {};
+}
+
+bool encodingsMayVary(Operation *op) {
+  return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp, triton::CatOp,
+             triton::TransOp>(op);
+}
+
 LogicalResult inferAutoLayouts(FuncOp func) {
   // Disallow auto encoding accross function call boundaries
   for (auto argTy : func.getArgumentTypes()) {
@@ -42,33 +104,37 @@ LogicalResult inferAutoLayouts(FuncOp func) {
         "Functions returning auto encoding must be fully inlined");
   }

-  llvm::MapVector<Value, Attribute> valueToEncoding;
+  llvm::MapVector<Value, LayoutInfo> valueToEncoding;
   llvm::PriorityWorklist<Value> worklist;
+  llvm::MapVector<Attribute, uint64_t> hashMemo;

   auto updateEncoding = [&](ArrayRef<Value> values,
-                            Attribute enc) -> LogicalResult {
+                            LayoutInfo info) -> LogicalResult {
     for (auto value : values) {
-      auto [it, inserted] = valueToEncoding.insert({value, enc});
+      auto [it, inserted] = valueToEncoding.insert({value, info});
       if (!inserted) {
-        if (it->second != enc) {
-          auto defOp = value.getDefiningOp();
-          auto op = defOp ? defOp : func;
-          return op->emitOpError("Found conflicting encodings for value");
-        }
-      } else {
-        LLVM_DEBUG({
-          DBGS() << "Setting value:\n\t" << value << "\nto encoding:\n\t" << enc
-                 << "\n";
-        });
-        worklist.insert(value);
+        auto defOp = value.getDefiningOp();
+        auto op = defOp ? defOp : func;
+        auto combine = combineInfo(it->second, info, op, hashMemo);
+        if (!combine)
+          return failure();
+        if (combine == it->second)
+          continue;
+        it->second = combine;
       }
+      LLVM_DEBUG({
+        DBGS() << "Setting value:\n\t" << value << "\nto encoding:\n\t"
+               << it->second << "\n";
+      });
+      worklist.insert(value);
     }
     return success();
   };

   // 1. Set seed values from set_auto_layout ops
   auto res = func.walk([&](gluon::SetAutoLayoutOp op) -> WalkResult {
-    return updateEncoding({op.getSrc()}, op.getType().getEncoding());
+    return updateEncoding({op.getSrc()},
+                          LayoutInfo{op.getType().getEncoding()});
   });

   if (res.wasInterrupted())
@@ -77,26 +143,28 @@ LogicalResult inferAutoLayouts(FuncOp func) {
   // 2. Propagate encodings through the graph until fixed point, or conflict
   while (!worklist.empty()) {
     auto val = worklist.pop_back_val();
-    auto enc = valueToEncoding[val];
-    assert(enc);
+    auto info = valueToEncoding[val];
+    assert(info);

     // Propagate to users
     for (OpOperand &use : val.getUses()) {
       auto op = use.getOwner();
       if (isa<scf::ForOp, scf::WhileOp>(op)) {
         auto offset = 3 * isa<scf::ForOp>(op);
         auto tiedArgs = getTiedArgs(op, use.getOperandNumber() - offset);
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else if (isa<scf::YieldOp>(op)) {
         auto tiedArgs = getTiedArgs(op, use.getOperandNumber());
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else {
-        auto dstEnc = inferDstEncoding(op, enc);
+        auto dstEnc = inferDstEncoding(op, info.encoding);
         if (dstEnc) {
+          bool mayVary = info.mayVary || encodingsMayVary(op);
+          LayoutInfo dstInfo{dstEnc, mayVary};
           if (failed(updateEncoding(llvm::to_vector_of<Value>(op->getResults()),
-                                    dstEnc)))
+                                    dstInfo)))
             return failure();
         }
       }
@@ -107,17 +175,19 @@ LogicalResult inferAutoLayouts(FuncOp func) {
       auto definingOp = opResult.getOwner();
       if (isa<scf::ForOp, scf::WhileOp, scf::IfOp>(definingOp)) {
         auto tiedArgs = getTiedArgs(definingOp, opResult.getResultNumber());
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       } else {
-        auto srcEncoding = inferSrcEncoding(definingOp, enc);
+        auto srcEncoding = inferSrcEncoding(definingOp, info.encoding);
         if (srcEncoding) {
+          bool mayVary = info.mayVary || encodingsMayVary(definingOp);
+          LayoutInfo srcInfo{srcEncoding, mayVary};
           llvm::SmallVector<Value> tensorOperands;
           for (auto operand : definingOp->getOperands())
             if (isa<RankedTensorType>(operand.getType()))
               tensorOperands.push_back(operand);

-          if (failed(updateEncoding(tensorOperands, srcEncoding)))
+          if (failed(updateEncoding(tensorOperands, srcInfo)))
             return failure();
         }
       }
@@ -126,18 +196,18 @@ LogicalResult inferAutoLayouts(FuncOp func) {
       if (isa<scf::ForOp, scf::WhileOp>(parentOp)) {
         auto offset = isa<scf::ForOp>(parentOp);
         auto tiedArgs = getTiedArgs(parentOp, blockArg.getArgNumber() - offset);
-        if (failed(updateEncoding(tiedArgs, enc)))
+        if (failed(updateEncoding(tiedArgs, info)))
           return failure();
       }
     }
   }

   // 3. Transfer propagated encodings into the graph
   auto ctx = func.getContext();
-  for (auto &[val, enc] : valueToEncoding) {
+  for (auto &[val, info] : valueToEncoding) {
     auto existingTy = cast<RankedTensorType>(val.getType());
     assert(isa<gluon::AutoEncodingAttr>(existingTy.getEncoding()));
-    auto ty = existingTy.cloneWithEncoding(enc);
+    auto ty = existingTy.cloneWithEncoding(info.encoding);
     val.setType(ty);

     if (auto opResult = dyn_cast<OpResult>(val)) {
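
The new combineInfo merges two candidate encodings recorded for the same value: the pair is first ordered by a memoized hash of each attribute's printed form (llvm::hash_value is not stable, hence the string hash) so the merge is commutative; an encoding whose mayVary flag is set yields to a fixed one; identical encodings pass through; and two distinct fixed encodings are a genuine conflict. Below is a minimal standalone sketch of that decision logic, using strings for encodings and Python's built-in hash() in place of xxh3_64bits over the printed attribute; both are illustrative stand-ins, not the pass's actual types.

from dataclasses import dataclass
from typing import Optional

@dataclass
class LayoutInfo:
    encoding: str           # stand-in for an mlir::Attribute
    may_vary: bool = False  # set for encodings derived through ops like join/reshape

def combine(lhs: LayoutInfo, rhs: LayoutInfo) -> Optional[LayoutInfo]:
    # Order the inputs deterministically so combine(a, b) == combine(b, a);
    # the pass hashes the attribute's printed text, hash() stands in here.
    if lhs.encoding != rhs.encoding and hash(lhs.encoding) > hash(rhs.encoding):
        lhs, rhs = rhs, lhs
    if lhs.may_vary:            # a flexible encoding yields to the other side
        return rhs
    if rhs.may_vary:
        return lhs
    if lhs.encoding == rhs.encoding:
        return lhs
    return None                 # two distinct fixed encodings: a real conflict

fixed = LayoutInfo("#blocked")
flexible = LayoutInfo("#sliced", may_vary=True)
assert combine(fixed, flexible).encoding == "#blocked"
assert combine(flexible, fixed).encoding == "#blocked"
assert combine(LayoutInfo("#a"), LayoutInfo("#b")) is None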

python/test/unit/language/test_core.py

Lines changed: 3 additions & 0 deletions
@@ -4527,6 +4527,9 @@ def make_finite(x, dtype):
     assert 'st.global.v4' in ptx
     assert (re.search(r'(mma|wgmma.mma_async).sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.(f|bf)16.(f|bf)16', ptx)
             or "tcgen05.mma.cta_group::1.kind::f16" in ptx)
+    if is_hip_cdna4() and normal_type in ["bf16", "fp16"]:
+        amdgcn = pgm.asm['amdgcn']
+        assert (re.search(r"v_cvt_scalef32_pk_.*?(fp4|fp8|bf8).*?op_sel", amdgcn))


 @pytest.mark.interpreter

python/triton/language/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -84,7 +84,7 @@
     join,
     load,
     make_block_ptr,
-    map_elementwise,  # noqa
+    map_elementwise,
     max_constancy,
     max_contiguous,
     maximum,
@@ -209,6 +209,7 @@
     "log",
     "log2",
     "make_block_ptr",
+    "map_elementwise",
     "math",
     "max",
     "max_constancy",

python/triton/runtime/interpreter.py

Lines changed: 0 additions & 7 deletions
@@ -93,13 +93,6 @@ def validate(self):
         for stride in self.strides[:-1]:
             assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
         assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
-        for i in range(self.ndim - 1):
-            stride = self.strides[i].data.item()
-            prev_stride = self.strides[i + 1].data.item()
-            prev_size = self.shape[i + 1].data.item()
-            assert stride >= prev_stride, "strides must be ordered largest to smallest"
-            assert (stride % prev_stride) == 0, "strides must be even multiples of smaller strides"
-            assert (stride // prev_stride) >= prev_size, "invalid stride"

     def materialize_pointers(self, offsets: List[TensorHandle]):
         assert len(offsets) == self.ndim
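
The removed loop required strides to be sorted largest to smallest, to divide each other evenly, and to be large enough to avoid overlapping dimensions; after this hunk the interpreter's descriptor validation keeps only the 16-alignment check on leading strides and the contiguous-last-dimension check. A small Python paraphrase of the checks that remain, using plain ints instead of the interpreter's TensorHandle wrappers (illustrative only, not the interpreter's code):

def remaining_checks(strides):
    # The two assertions that survive this commit, restated over plain ints.
    for stride in strides[:-1]:
        assert stride % 16 == 0, "stride must be 16-byte aligned"
    assert strides[-1] == 1, "last dim must be contiguous"

# Strides that are not ordered largest to smallest (e.g. the middle dimension
# is the slowest-varying one) were rejected by the removed loop but pass now.
remaining_checks((16, 128, 1))

# A non-contiguous last dimension is still rejected.
try:
    remaining_checks((128, 16, 4))
except AssertionError as e:
    print("rejected:", e)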

test/Analysis/amd/test-alignment.mlir

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>

-tt.func public @kernel(%arg0: tensor<256x64xf16, #mma> {tt.contiguity=256 : i32, tt.divisibility=6: i32, tt.constancy=1: i32}) attributes {noinline = false} {
+tt.func public @kernel(%arg0: tensor<256x64xf16, #mma> {tt.contiguity=256 : i32, tt.divisibility=6: i32, tt.constancy=1: i32}) {
   // expeted-remark @below {{contiguity = [128, 32], divisibility = [6, 6], constancy = [1, 1], constant_value = <none>}}
   %0 = amdgpu.extract_slice %arg0 [128, 32] : tensor<256x64xf16, #mma> to tensor<128x32xf16, #mma>
   tt.return

test/Conversion/amd/buffer_atomic_cas.mlir

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: buffer_atomic_cas_i64
-  tt.func public @buffer_atomic_cas_i64(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+  tt.func public @buffer_atomic_cas_i64(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     // CHECK: %[[cas_val:.*]] = llvm.mlir.constant(2 : i64) : i64
     // CHECK: %[[cas_val_cast:.*]] = llvm.bitcast %[[cas_val]] : i64 to i64
     // CHECK: %[[cas_val_insert:.*]] = llvm.insertvalue %[[cas_val_cast]], %{{.*}}[1] : !llvm.struct<(i64, i64)>

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: strided_buffer_load_and_store
-  tt.func public @strided_buffer_load_and_store(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+  tt.func public @strided_buffer_load_and_store(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     %cst = arith.constant dense<2> : tensor<1024xi32, #blocked>
     %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
     %1 = arith.muli %0, %cst : tensor<1024xi32, #blocked>