
Commit 7c407a3

OpenXLA-specific changes
1 parent 905232b commit 7c407a3

File tree

40 files changed, +2426 -149 lines changed

BUILD

Lines changed: 910 additions & 0 deletions
Large diffs are not rendered by default.

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73
+bef3b54ea10a564a2de72f658f2efd64f537c079

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 0 additions & 2 deletions
@@ -108,8 +108,6 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";

   let hasVerifier = 1;
-
-  let hasFolder = 1;
 }

 //

lib/Analysis/AxisInfo.cpp

Lines changed: 1 addition & 1 deletion
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }

   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
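The one-character fix above addresses a 32-bit overflow: the literal 1 in 1 << shift is an int, so any shift of 32 or more is undefined behavior, even though lhsDivisibility itself is an int64_t. A minimal standalone sketch of the failure mode (illustrative only, not Triton code):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t lhsDivisibility = int64_t(1) << 40; // a large power-of-two divisibility
      unsigned shift = 40;
      // Undefined behavior: `1` is a 32-bit int, so `1 << 40` overflows it
      // (in practice often yielding 0 or garbage):
      //   int64_t bad = std::max<int64_t>(1, lhsDivisibility / (1 << shift));
      // Well-defined: widen to 64 bits before shifting, as the patch does.
      int64_t good = std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
      std::cout << good << "\n"; // prints 1
    }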

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 4 additions & 2 deletions
@@ -59,7 +59,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   auto ouEltTy = ouTensorTy.getElementType();
   if (inBitWidth == ouBitWidth)
     return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
+  if ((inBitWidth == 16 && ouBitWidth == 32) ||
+      (inBitWidth == 32 && ouBitWidth == 16)) {
     // Register layout conversion:
     //
     // [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
@@ -85,7 +86,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
     }
     return ret;
   }
-  if (inBitWidth == 8 && ouBitWidth == 16) {
+  if ((inBitWidth == 8 && ouBitWidth == 16) ||
+      (inBitWidth == 16 && ouBitWidth == 8)) {
     // Register layout conversion:
     //
     // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
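Both branches now fire in either direction because the register regrouping is symmetric: the bracketed comments show groups of values sharing a 32-bit register being split in half when the element width doubles (and, read right to left, merged when it halves). A toy standalone sketch of the splitting step named in those comments (illustrative; not the actual Triton helper):

    #include <cstdio>
    #include <vector>

    // Split each group of contiguous register lanes in half, mirroring the
    // [0, 1], [4, 5] -> [0], [1], [4], [5] pattern from the comment above.
    static std::vector<std::vector<int>>
    splitGroupsInHalf(const std::vector<std::vector<int>> &groups) {
      std::vector<std::vector<int>> out;
      for (const auto &g : groups) {
        size_t half = g.size() / 2;
        out.emplace_back(g.begin(), g.begin() + half);
        out.emplace_back(g.begin() + half, g.end());
      }
      return out;
    }

    int main() {
      // 16-bit -> 32-bit case: pairs of lanes become singletons.
      for (const auto &g : splitGroupsInHalf({{0, 1}, {4, 5}})) {
        printf("[");
        for (int v : g)
          printf(" %d", v);
        printf(" ]");
      }
      printf("\n"); // prints [ 0 ][ 1 ][ 4 ][ 5 ]
    }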

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 0 additions & 23 deletions
@@ -728,29 +728,6 @@ LogicalResult ReshapeOp::verify() {
 }

 //-- FpToFpOp --
-
-// Fold FpToFpOp when the input operand is a constant zero.
-OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
-  auto srcVal = getSrc();
-  auto dstTy = getType();
-
-  const llvm::fltSemantics &semantic =
-      llvm::cast<FloatType>(dstTy.getElementType()).getFloatSemantics();
-
-  if (matchPattern(srcVal, m_PosZeroFloat())) {
-    llvm::APFloat posZero =
-        llvm::APFloat::getZero(semantic, /*negative=*/false);
-    return DenseFPElementsAttr::get(dstTy, posZero);
-  }
-
-  if (matchPattern(srcVal, m_NegZeroFloat())) {
-    llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true);
-    return DenseFPElementsAttr::get(dstTy, negZero);
-  }
-
-  return {};
-}
-
 LogicalResult FpToFpOp::verify() {
   auto dstType = getType().getElementType();
   auto srcType = getSrc().getType().getElementType();
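For reference, the folder deleted above turned fp_to_fp of a constant ±0.0 into a dense splat of the correspondingly signed zero in the destination semantics. Its core building block is llvm::APFloat::getZero; a standalone sketch of just that piece (requires LLVM headers; the half-precision semantics here are an arbitrary illustrative choice):

    #include "llvm/ADT/APFloat.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      // Signed zeros in IEEE half precision -- the payloads the removed
      // FpToFpOp::fold placed into its DenseFPElementsAttr results.
      const llvm::fltSemantics &sem = llvm::APFloat::IEEEhalf();
      llvm::APFloat posZero = llvm::APFloat::getZero(sem, /*negative=*/false);
      llvm::APFloat negZero = llvm::APFloat::getZero(sem, /*negative=*/true);
      llvm::outs() << (posZero.isNegZero() ? "-0" : "+0") << " "
                   << (negZero.isNegZero() ? "-0" : "+0") << "\n"; // +0 -0
    }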

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 5 additions & 0 deletions
@@ -2801,6 +2801,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 24 additions & 0 deletions
@@ -162,6 +162,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }

@@ -171,6 +186,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

   static bool bwdFilter(Operation *op) {
+    // Assigning a dot operand layout to predicates (i1) is not currently
+    // supported when lowering from TritonGPU to LLVM for MMA cases. This
+    // condition limits visibility of the original bit-width so that
+    // predicates are not considered, hence kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 16 additions & 1 deletion
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                 PatternRewriter &rewriter) const override {
     // Only consider conversions to dot operand.
     auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
       return failure();

     auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                      [](Type ty) { return isa<RankedTensorType>(ty); }))
       return failure();

+    // Quick fix for loading issues where the original-bitwidth computation
+    // fails to realize that the dot is mixed-precision (hence kWidth = 1)
+    // yet would hoist through the type conversion anyway.
+    if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+      return failure();
+
     // Only consider custom conversions or arith ops.
     // TODO(jlebar): Is this too restrictive?
     if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
     if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
       return failure();

+    // Don't hoist through u1 -> fp casts as they aren't supported in
+    // ElementwiseOpToLLVM::reorderValues().
+    if (isa<arith::UIToFPOp>(src)) {
+      Type srcType = getElementTypeOrSelf(src->getOperand(0));
+      if (srcType.isInteger(1))
+        return failure();
+    }
+
     // Check that the conversion is transitively dependent on a load, and all
     // operations between the load and the conversion are layout preserving.
     //

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

Lines changed: 24 additions & 2 deletions
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
   // opIdx: 0 => a, 1 => b
   auto type = cast<triton::MemDescType>(v.getType());
   SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
-  SmallVector<int64_t> offset{0, 0};
+  SmallVector<int64_t> offset(shape.size(), 0);
   Type elementType = type.getElementType();

   // k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                         type.getMemorySpace()),
       v, offsetsVal);

+  // We need to set kwidth to zero when the parent layout is Blocked;
+  // otherwise the verifier emits a failure. The parent layout is Blocked
+  // only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
      newSmem);
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
         break;
       if (!op->getResult(0).hasOneUse())
         break;
+      // Similar to the issues faced in the HoistLayoutConversion pattern in
+      // OptimizeDotOperands.cpp, we can't propagate through type casts from
+      // predicates, as they aren't supported in Triton when encoded with a
+      // dot_op layout.
+      if (isa<arith::UIToFPOp>(op)) {
+        Type srcType = getElementTypeOrSelf(op->getOperand(0));
+        if (srcType.isInteger(1))
+          break;
+      }
+      // Propagation through ExpandDims is currently not supported: it would
+      // blindly replace the encoding with a dot encoding, but ExpandDims
+      // requires a SliceEncoding. This could be rewritten to support it
+      // somehow, but I don't think it's trivial & it's currently crashing.
+      if (isa<ExpandDimsOp>(op)) {
+        break;
+      }
       rets.push_back(op->getOperand(0));
       if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
         foundConvertFromShared = true;
