Commit f694fd7

Merge commit '00cf53fe57332b463f02a427be65e36c91f544bc'
2 parents: d478b30 + 00cf53f

File tree

33 files changed: +997 -685 lines

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,8 @@
 #ifndef TRITON_IR_UTILITY_H_
 #define TRITON_IR_UTILITY_H_

+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include <algorithm>
 #include <numeric>
@@ -10,6 +12,14 @@ namespace mlir {
 // Bitwidth of pointers
 constexpr int kPtrBitWidth = 64;

+// Returns the bit width of a type, treating pointer-like types as 64-bit.
+// This handles both LLVM dialect and Triton pointer types.
+inline int getIntOrFloatOrPtrBitWidth(Type type) {
+  if (isa<LLVM::LLVMPointerType, triton::PointerType>(type))
+    return kPtrBitWidth;
+  return type.getIntOrFloatBitWidth();
+}
+
 template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
   SmallVector<T> out;
   for (const auto &i : in)
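
Note: the new helper centralizes a pattern repeated in the files below. MLIR's Type::getIntOrFloatBitWidth() asserts on pointer types, which carry no intrinsic width, so callers previously had to special-case them. A minimal standalone sketch of the same dispatch, using stand-in types rather than the real MLIR API:

#include <cassert>

// Stand-in type model; the real helper dispatches on
// LLVM::LLVMPointerType and triton::PointerType instead.
struct Type {
  enum Kind { Int, Float, Ptr } kind;
  int width; // intrinsic width; meaningless for Ptr
};

constexpr int kPtrBitWidth = 64;

int getIntOrFloatOrPtrBitWidth(Type type) {
  if (type.kind == Type::Ptr)
    return kPtrBitWidth; // pointers have no intrinsic width; assume 64 bits
  return type.width;     // mirrors Type::getIntOrFloatBitWidth()
}

int main() {
  assert(getIntOrFloatOrPtrBitWidth({Type::Int, 32}) == 32);
  assert(getIntOrFloatOrPtrBitWidth({Type::Ptr, 0}) == kPtrBitWidth);
  return 0;
}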

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 3 additions & 4 deletions
@@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);

-LinearLayout chooseScaledWmmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape);
+LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<unsigned> warpsPerCTA,
+                                         ArrayRef<int64_t> dotOperandShape);

 LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
                                           ArrayRef<int64_t> shape, int opIdx,

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 1 deletion
@@ -152,7 +152,8 @@ class AllocationAnalysis {
       auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
       numElems = product<int64_t>(shapePerCTA);
     }
-    int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;
+    int64_t bytes =
+        numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8;

     auto alignment = alloc.getAlignmentOrDefault();
     allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
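
Note: this makes shared-memory sizing correct for allocations whose element type is a pointer, which the previous getElementTypeBitWidth() call could not handle. A worked example of the new byte computation, with a hypothetical 128x64 allocation of pointers:

#include <cassert>
#include <cstdint>

int main() {
  int64_t numElems = 128 * 64; // product of a hypothetical shapePerCTA
  int bitWidth = 64;           // getIntOrFloatOrPtrBitWidth() on a pointer
  int64_t bytes = numElems * bitWidth / 8;
  assert(bytes == 65536);
  return 0;
}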

lib/Analysis/AxisInfo.cpp

Lines changed: 74 additions & 58 deletions
@@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
     auto lhsInfo = operands[0]->getValue();
     auto rhsInfo = operands[1]->getValue();
     auto rank = lhsInfo.getRank();
+    assert(isa<RankedTensorType>(op.getType()) ||
+           rank == 1 && "Expected ranked tensor or scalar");
     assert(operands.size() == 2 && "Expected two operands");
+    auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
+    if (constantValue.has_value()) {
+      auto resTy = dyn_cast<RankedTensorType>(op.getType());
+      AxisInfo::DimVectorT constancy =
+          resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+      AxisInfo::DimVectorT contiguity(rank, 1);
+      AxisInfo::DimVectorT divisibility(
+          rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
+      return AxisInfo(contiguity, divisibility, constancy, constantValue);
+    }
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
     for (auto d = 0; d < rank; ++d) {
-      if (constantValue.has_value()) {
-        contiguity.push_back(1);
-        constancy.push_back(
-            std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
-        divisibility.push_back(
-            highestPowOf2Divisor<int64_t>(constantValue.value()));
-      } else {
-        contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
-        constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
-        divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
-      }
+      contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
+      constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
+      divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
     }
     return AxisInfo(contiguity, divisibility, constancy, constantValue);
   }
@@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {

   virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
                                const AxisInfo &rhs, int dim) {
-    return 1;
+    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
   }
-
   virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                                   const AxisInfo &rhs) {
     return {};
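
Note: making gcd the base-class default is what allows the per-op getConstancy overrides below to be deleted. For an elementwise binary op, if lhs is constant in aligned runs of length c_l and rhs in runs of length c_r along a dimension, the result is constant in runs of at least gcd(c_l, c_r). A small sanity check of that rule (assuming runs aligned to multiples of their length, which is how AxisInfo defines constancy):

#include <cassert>
#include <numeric>

int main() {
  long lhsConstancy = 8, rhsConstancy = 4;
  // Every aligned run of 4 lies inside a single lhs run and a single
  // rhs run, so any elementwise op yields runs of at least gcd(8, 4).
  assert(std::gcd(lhsConstancy, rhsConstancy) == 4);
  return 0;
}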
@@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   }
 };

+class UnrealizedConversionCastOpAxisInfoVisitor final
+    : public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
+public:
+  using AxisInfoVisitorImpl<
+      mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;
+
+  AxisInfo
+  getAxisInfo(mlir::UnrealizedConversionCastOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
+    auto tensorType = dyn_cast<RankedTensorType>(op.getResultTypes()[0]);
+    if (tensorType &&
+        tensorType.getRank() != operands[0]->getValue().getRank()) {
+      // Do not propagate AxisInfo with incorrect rank. This can cause a crash
+      // in future visitor applications.
+      return AxisInfo::getPessimisticValueState(op->getResult(0));
+    }
+    return operands[0]->getValue();
+  }
+};
+
 class MakeRangeOpAxisInfoVisitor final
     : public AxisInfoVisitorImpl<triton::MakeRangeOp> {
 public:
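
Note: the rank guard matters because AxisInfo stores one contiguity/divisibility/constancy entry per dimension, and later visitors index those vectors by dimension. A stand-in sketch of the guard, with a hypothetical Info struct rather than the real AxisInfo class:

#include <cassert>
#include <vector>

struct Info {
  std::vector<long> constancy; // one entry per dimension
};

Info propagateThroughCast(const Info &in, long resultRank) {
  if ((long)in.constancy.size() != resultRank)
    return Info{std::vector<long>(resultRank, 1)}; // pessimistic fallback
  return in; // ranks agree: forward the operand's info unchanged
}

int main() {
  Info oneD{{4}};
  assert(propagateThroughCast(oneD, 2).constancy.size() == 2); // guarded
  assert(propagateThroughCast(oneD, 1).constancy[0] == 4);     // forwarded
  return 0;
}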
@@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
     return gcd(lhs.getDivisibility(dim), rhsDivisibility);
   }

-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
     return std::max(lhsContiguity, rhsContiguity);
   }

-  int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
-                       const AxisInfo &rhs, int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
                           const AxisInfo &rhs, int dim) override {
     auto lhsDivisibility = lhs.getDivisibility(dim);
@@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {

   std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
-    if (lhs.getConstantValue().has_value() &&
-        rhs.getConstantValue().has_value())
-      return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
+    auto lhsConst = lhs.getConstantValue();
+    auto rhsConst = rhs.getConstantValue();
+    if (lhsConst.has_value() && rhsConst.has_value())
+      return {lhsConst.value() * rhsConst.value()};
+    if ((lhsConst.has_value() && lhsConst.value() == 0) ||
+        (rhsConst.has_value() && rhsConst.value() == 0))
+      return 0;
     return {};
   }
 };
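
Note: the new branch folds multiplication by a known zero even when the other operand is entirely unknown, since x * 0 == 0 for any x. A self-contained sketch of the folding logic:

#include <cassert>
#include <cstdint>
#include <optional>

std::optional<int64_t> foldMul(std::optional<int64_t> lhs,
                               std::optional<int64_t> rhs) {
  if (lhs && rhs)
    return *lhs * *rhs;
  if ((lhs && *lhs == 0) || (rhs && *rhs == 0))
    return 0; // zero absorbs an unknown operand
  return std::nullopt;
}

int main() {
  assert(foldMul(3, 4) == 12);
  assert(foldMul(std::nullopt, 0) == 0);
  assert(!foldMul(3, std::nullopt).has_value());
  return 0;
}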
@@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
     auto resTy = dyn_cast<RankedTensorType>(op.getType());
+    auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     if (!resTy)
-      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
+      return constancy;
     auto shape = resTy.getShape();
-    // Case 1: both lhs and rhs are constants.
-    auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-    // Case 2: lhs contiguous, rhs constant.
+    // Case: lhs contiguous, rhs constant.
     // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
     // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
     // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
@@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {

   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
+    auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     auto resTy = dyn_cast<RankedTensorType>(op.getType());
     if (!resTy)
-      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
-    auto shape = resTy.getShape();
-    // lhs % 1 = 0
-    return rhs.getConstantValue().has_value() &&
-                   rhs.getConstantValue().value() == 1
-               ? shape[dim]
-               : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
+      return constancy;
+    // Case: lhs % 1 = 0
+    if (rhs.getConstantValue().has_value() &&
+        rhs.getConstantValue().value() == 1)
+      return resTy.getDimSize(dim);
+    return constancy;
   }

   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
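
Note: the surviving special case relies on x % 1 == 0 holding for every x, so the result is constant across the entire dimension rather than only across gcd-sized runs. A trivial check:

#include <cassert>

int main() {
  for (int x : {0, 1, 5, 1024})
    assert(x % 1 == 0); // remainder by 1 is uniformly zero
  return 0;
}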
@@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
       int64_t constHint = 1;
       if (lhsInfo.getConstantValue().has_value() &&
           rhsInfo.getConstantValue().has_value()) {
-        constHint = lhsInfo.getConstancy(d);
+        constHint = shape[d];
         constantValue =
             compare(getPredicate(op), lhsInfo.getConstantValue().value(),
                     rhsInfo.getConstantValue().value())
@@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
         rhsInfo.getConstantValue().has_value() &&
         lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
       constantValue = lhsInfo.getConstantValue();
+
+    if (constantValue.has_value()) {
+      auto resTy = dyn_cast<RankedTensorType>(op.getType());
+      assert(resTy || rank == 1);
+      constancy =
+          resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+    }
   }

   return AxisInfo(contiguity, divisibility, constancy, constantValue);
@@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
   using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

 private:
-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
     return multiplyDivisor(lhsDivisibility, 1ll << shift);
   }

-  int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
-                       const AxisInfo &rhs, int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
     return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }

-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
         constantValue = {std::min(lhsInfo.getConstantValue().value(),
                                   rhsInfo.getConstantValue().value())};
       }
+      auto resTy = dyn_cast<RankedTensorType>(op.getType());
+      assert(resTy || rank == 1);
+      AxisInfo::DimVectorT constancy =
+          resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+      AxisInfo::DimVectorT divisibility(
+          rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
       return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
-                      /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
-                      /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
+                      /*knownDivisibility=*/divisibility,
+                      /*knownConstancy=*/constancy,
                       /*constantValue=*/constantValue);
     } else {
       AxisInfo::DimVectorT contiguity, divisibility, constancy;
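
Note: besides full-dimension constancy, the constant branch now derives divisibility from the value itself via its largest power-of-two divisor instead of defaulting to 1. A sketch of that computation under assumed semantics of highestPowOf2Divisor (the real helper may treat 0 specially):

#include <cassert>
#include <cstdint>

int64_t highestPowOf2Divisor(int64_t n) {
  // For positive n, the lowest set bit is the largest power of two
  // dividing n, e.g. 24 = 0b11000 -> 8.
  assert(n > 0);
  return n & -n;
}

int main() {
  assert(highestPowOf2Divisor(24) == 8);
  assert(highestPowOf2Divisor(7) == 1);
  assert(highestPowOf2Divisor(64) == 64);
  return 0;
}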
@@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
   // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
   // in the process of a PartialConversion, where UnrealizedConversionCast
   // may exist
+  visitors.append<UnrealizedConversionCastOpAxisInfoVisitor>();
   visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
                   CastOpAxisInfoVisitor<arith::ExtUIOp>,
                   CastOpAxisInfoVisitor<arith::TruncIOp>,
                   CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
-                  CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
                   CastOpAxisInfoVisitor<triton::BitcastOp>>();
   visitors.append<MakeRangeOpAxisInfoVisitor>();
   visitors.append<PoisonOpAxisInfoVisitor>();
@@ -1384,7 +1397,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
      callee.setArgAttr(index, attrName, attr);
    };
    auto axisInfo = axisInfoMap->lookup(value);
-    assert(axisInfo.getRank() == 1 && "only scalar arguments are supported");
+    // Only scalar arguments are supported. Do not forward multi-dimensional
+    // AxisInfo to the callee.
+    if (axisInfo.getRank() != 1)
+      continue;
    setAttrFn("tt.contiguity", axisInfo.getContiguity(0));
    setAttrFn("tt.divisibility", axisInfo.getDivisibility(0));
    setAttrFn("tt.constancy", axisInfo.getConstancy(0));

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
@@ -26,8 +26,6 @@ struct ConvertLayoutOpConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;

-  // Set benefit to 2 so that this pattern applies before other convert-layout
-  // conversions. TODO(jlebar): Eventually we want this to be the only pattern.
   explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
                                      const TargetInfoBase &targetInfo,
                                      PatternBenefit benefit = 1)
@@ -277,8 +275,7 @@ struct ConvertLayoutOpConversion
     StringAttr kReg = str_attr("register");
     StringAttr kLane = str_attr("lane");
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
-    int bitwidth =
-        elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth;
+    int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);

     auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth);
     auto &[pReg, pLane, mixedTranspositions, nPack] = factors;

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion
     auto ty = getTypeConverter()->convertType(getElementType(result));

     // Pack return elements into 32-bits.
-    unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64;
+    unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty);
     unsigned numElemsPerReg =
         std::min(std::max(32 / bitWidth, 1u), op.getPackedElement());
     assert(op.getPackedElement() % numElemsPerReg == 0);
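
Note: with the helper, 16-bit elements still pack two per 32-bit register, while pointer results, now reported as 64-bit instead of falling back to a bare literal, get one element per register. A worked check of the same min/max expression:

#include <algorithm>
#include <cassert>

unsigned numElemsPerReg(unsigned bitWidth, unsigned packedElement) {
  // 32u / 64u == 0 in unsigned math, so std::max(..., 1u) clamps wide
  // types to one element per register.
  return std::min(std::max(32u / bitWidth, 1u), packedElement);
}

int main() {
  assert(numElemsPerReg(16, 4) == 2); // two halves per register
  assert(numElemsPerReg(64, 4) == 1); // pointer treated as 64-bit
  return 0;
}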

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
@@ -540,7 +540,7 @@ SmallVector<Value> lowerLdSt(
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
-  auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
+  auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);

   auto [elemsPerVec, permutation] =
       largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);
@@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
   assert(*cvt.getOutDimNames().begin() == str_attr("offset"));
   auto calcPaddedOffset = [&](Value smemOffset) {
     TritonLLVMOpBuilder b(loc, rewriter);
-    auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
+    auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);
     if (auto paddedEnc = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
             srcTy.getEncoding())) {
       // Apply the offset needed for padding.

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 12 additions & 2 deletions
@@ -11,6 +11,16 @@ using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 namespace {
+
+Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) {
+  if (isa<LLVM::LLVMPointerType>(val.getType()) &&
+      !isa<LLVM::LLVMPointerType>(type)) {
+    return b.ptrtoint(type, val);
+  } else {
+    return b.bitcast(val, type);
+  }
+}
+
 struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
   using ConvertOpToLLVMPattern<triton::SplatOp>::ConvertOpToLLVMPattern;
   // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
       unsigned ratio = srcBitWidth / cstBitWidth;
       Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth);
       VectorType vecType = VectorType::get(ratio, intTy);
-      Value intCst = b.bitcast(constVal, intTy);
+      Value intCst = bitOrPtrCast(constVal, intTy, b);
       Value vec = b.undef(vecType);
       for (unsigned i = 0; i < ratio; ++i)
         vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i));
       constVal = vec;
     }
-    auto llSrc = b.bitcast(constVal, srcType);
+    Value llSrc = bitOrPtrCast(constVal, srcType, b);
     size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
     llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
     return packLLElements(loc, typeConverter, elems, rewriter, resType);
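
Note: bitOrPtrCast exists because LLVM IR forbids bitcast between pointer and integer types; a pointer value being reinterpreted as a non-pointer type must go through ptrtoint. A stand-in sketch of the dispatch, with enums in place of the MLIR builder API:

#include <cassert>

enum class Ty { I64, Ptr };
enum class CastKind { Bitcast, PtrToInt };

// Mirrors bitOrPtrCast: pointer -> non-pointer requires ptrtoint;
// everything else can use a plain bitcast.
CastKind chooseCast(Ty src, Ty dst) {
  if (src == Ty::Ptr && dst != Ty::Ptr)
    return CastKind::PtrToInt;
  return CastKind::Bitcast;
}

int main() {
  assert(chooseCast(Ty::Ptr, Ty::I64) == CastKind::PtrToInt);
  assert(chooseCast(Ty::I64, Ty::I64) == CastKind::Bitcast);
  return 0;
}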
