Skip to content

Commit 7a5940c

Browse files
jax-triton-devkarupayun
authored andcommitted
OpenXLA-specific changes
1 parent b2de88f commit 7a5940c

File tree

45 files changed

+2187
-114
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2187
-114
lines changed

BUILD

Lines changed: 900 additions & 0 deletions
Large diffs are not rendered by default.

include/triton/Analysis/Alias.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,9 @@ class SharedMemoryAliasAnalysis
8585
}
8686

8787
/// Computes if the alloc set of the results are changed.
88-
void
89-
visitOperation(Operation *op,
90-
ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
91-
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
88+
LogicalResult visitOperation(
89+
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
90+
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
9291
};
9392

9493
} // namespace mlir

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
1515
}
1616

1717
// Floating-point Type
18-
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
18+
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
1919
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
2020
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
2121

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
3131

3232
/// Get the highest power of 2 divisor of an integer.
3333
template <typename T> T highestPowOf2Divisor(T n) {
34-
if (n == 0) {
34+
// When n is 0 or min, return the highest power of 2. The min case is handled
35+
// separately to avoid underflow when T is a signed integer. Technically
36+
// in that case the correct divisor is -n, but this value is outside the
37+
// range of possible values, so we take the next best alternative.
38+
if (n == 0 || n == std::numeric_limits<T>::min()) {
3539
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
3640
}
3741
return (n & (~(n - 1)));

lib/Analysis/Alias.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
2121
return ret;
2222
}
2323

24-
void SharedMemoryAliasAnalysis::visitOperation(
24+
LogicalResult SharedMemoryAliasAnalysis::visitOperation(
2525
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
2626
ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
2727
AliasInfo aliasInfo;
@@ -31,7 +31,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
3131
if (auto memdescTy = dyn_cast<triton::MemDescType>(result.getType())) {
3232
if (!isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
3333
memdescTy.getMemorySpace()))
34-
return;
34+
return mlir::success();
3535
}
3636

3737
// Only LocalAllocOp creates a new buffer.
@@ -49,11 +49,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
4949
}
5050

5151
if (pessimistic) {
52-
return setAllToEntryStates(results);
52+
setAllToEntryStates(results);
53+
return mlir::success();
5354
}
5455
// Join all lattice elements
5556
for (auto *result : results)
5657
propagateIfChanged(result, result->join(aliasInfo));
58+
return mlir::success();
5759
}
5860

5961
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {

lib/Analysis/AxisInfo.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,9 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
195195
dataflow::Lattice<AxisInfo>>::getLatticeElement;
196196
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
197197

198-
void visitOperation(Operation *op,
199-
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
200-
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
198+
LogicalResult visitOperation(
199+
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
200+
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
201201
void
202202
visitForOpInductionVar(scf::ForOp op,
203203
ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices);
@@ -1039,7 +1039,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10391039
visitors.append<LoadOpAxisInfoVisitor>();
10401040
}
10411041

1042-
void AxisInfoAnalysis::visitOperation(
1042+
LogicalResult AxisInfoAnalysis::visitOperation(
10431043
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10441044
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
10451045
// TODO: For sure not the right way to do this
@@ -1048,8 +1048,10 @@ void AxisInfoAnalysis::visitOperation(
10481048
if (op->getValue().getRank() == 0)
10491049
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
10501050
AxisInfo curr = visitors.apply(op, operands);
1051-
if (curr.getRank() == 0)
1052-
return setAllToEntryStates(results);
1051+
if (curr.getRank() == 0) {
1052+
setAllToEntryStates(results);
1053+
return mlir::success();
1054+
}
10531055
// override with hint
10541056
auto newContiguity = curr.getContiguity();
10551057
auto newDivisibility = curr.getDivisibility();
@@ -1071,6 +1073,7 @@ void AxisInfoAnalysis::visitOperation(
10711073
// join all lattice elements
10721074
for (auto *result : results)
10731075
propagateIfChanged(result, result->join(curr));
1076+
return mlir::success();
10741077
}
10751078

10761079
void AxisInfoAnalysis::visitForOpInductionVar(

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
425425
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
426426
return false;
427427

428+
auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
428429
auto F8E5M2 = TypeID::get<Float8E5M2Type>();
429430
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
430431
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
@@ -436,6 +437,7 @@ bool supportMFMATypes(Type a, Type b) {
436437
{F32, F32},
437438
{F16, F16},
438439
{BF16, BF16},
440+
{F8E4M3FN, F8E4M3FN},
439441
{F8E5M2, F8E5M2},
440442
{F8E4M3FNUZ, F8E4M3FNUZ},
441443
{F8E4M3FNUZ, F8E5M2FNUZ},
@@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
495497
return false;
496498
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
497499
retShapePerCTA[rank - 1] % 8 == 0 &&
498-
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
500+
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
499501
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
500502
aElemTy.isF32()))) {
501503
return false;
502504
}
503505
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
504506
if (op.getMaxNumImpreciseAcc() < 32 &&
505-
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
507+
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
506508
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
507509
return false;
508510
}

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
4040
auto ouEltTy = ouTensorTy.getElementType();
4141
if (inBitWidth == ouBitWidth)
4242
return values;
43-
if (inBitWidth == 16 && ouBitWidth == 32) {
43+
if ((inBitWidth == 16 && ouBitWidth == 32) ||
44+
(inBitWidth == 32 && ouBitWidth == 16)) {
4445
SmallVector<Value> ret;
4546
for (unsigned i = 0; i < values.size(); i += 8) {
4647
ret.push_back(values[i]);

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
3434
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
3535
return IntegerType::get(type.getContext(), 8);
3636
});
37+
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
38+
return IntegerType::get(type.getContext(), 8);
39+
});
3740
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
3841
return IntegerType::get(type.getContext(), 8);
3942
});

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
8787
// LLVM IR.
8888
if (type::isFloat8(elemType))
8989
elemType = rewriter.getIntegerType(8);
90-
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
9190
auto typeConverter = getTypeConverter();
91+
auto constOp = rewriter.create<LLVM::ConstantOp>(
92+
loc, typeConverter->convertType(elemType), val);
9293
auto llStruct = SplatOpConversion::convertSplatLikeOp(
9394
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
9495
rewriter.replaceOp(op, llStruct);

0 commit comments

Comments
 (0)