
Commit c0987c7

Reland upstream commit 9e90089 (#2617)

Closes #2527. Please do not squash and merge this PR.

2 parents: 16b2057 + 19d3ed5

14 files changed: +185 -116 lines

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 7 deletions
@@ -246,13 +246,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
-
-// FIXME
-// Exposing to use it in LinearLayoutConversionsTest.cpp
-// Remove it once we fully activate the DotOperand conversion via LLs
-class DotOperandEncodingAttr;
-LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
-                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Analysis/Allocation.cpp

Lines changed: 6 additions & 1 deletion
@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

   assert(!isMfmaToDotShortcut(srcTy, dstTy));

-  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
   scratchConfig.order = outOrd;

   unsigned srcContigPerThread =

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
+        bool largeKWidth =
+            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+        if (mma.isAmpere() && largeKWidth) {
+          return;
+        }
+      }
+
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
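The new guard keys on how many contiguous bits each thread holds along K: kWidth elements times the element bit width. When that exceeds 64 bits, the blocked-to-dot-operand conversion is left to the linear-layout (LL) lowering instead of being decomposed here. A standalone sketch of the threshold arithmetic, with assumed example operands:

// Standalone illustration (not part of the commit) of the largeKWidth test.
#include <cstdio>

int main() {
  struct Case { const char *desc; unsigned kWidth, bitWidth; };
  Case cases[] = {
      {"bf16 operand, kWidth=8", 8, 16}, // 128 bits > 64 -> left to the LL path
      {"bf16 operand, kWidth=4", 4, 16}, // 64 bits -> still decomposed here
      {"fp8 operand,  kWidth=8", 8, 8},  // 64 bits -> still decomposed here
  };
  for (const auto &c : cases) {
    bool largeKWidth = c.kWidth * c.bitWidth > 64;
    std::printf("%s: largeKWidth=%s\n", c.desc, largeKWidth ? "true" : "false");
  }
  return 0;
}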

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 40 additions & 37 deletions
@@ -11,6 +11,7 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -237,13 +238,36 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }

+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  //         otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
       return getWarpOrder(dotLayout.getParent());
     }
   }
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -253,40 +277,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.insert(order.begin(), 0);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
-    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (dotOpLayout.getOpIdx() == 0) {
-      std::swap(order[0], order[1]);
-    }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  assert((rank == 2 || rank == 3) &&
-         "Invalid rank for dot operand order computation");
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
-  // the following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
-  //   1, 0] for 3D tensors.
-  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
-  //   2, 0] for 3D tensors.
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
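The new getOrderForDotOperand takes a kMajor flag instead of hard-coding one convention per operand. A minimal standalone sketch of the same iota/swap logic, printing the orders it produces for rank-3 operands (illustration only, not Triton code):

// Standalone illustration of getOrderForDotOperand(opIdx, rank, kMajor).
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                                bool kMajor) {
  assert((rank == 2 || rank == 3) && (opIdx == 0 || opIdx == 1));
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // [1, 0] or [2, 1, 0]
  if (bool(opIdx) == kMajor)                  // swap the two fastest dims
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  // opIdx=0 is [batch, m, k]; opIdx=1 is [batch, k, n]; batch stays slowest.
  // kMajor=true means k is the fastest-running dimension.
  for (unsigned opIdx : {0u, 1u})
    for (bool kMajor : {true, false}) {
      auto o = orderForDotOperand(opIdx, 3, kMajor);
      std::printf("opIdx=%u kMajor=%d -> [%u, %u, %u]\n", opIdx, (int)kMajor,
                  o[0], o[1], o[2]);
    }
  return 0;
}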
@@ -303,7 +295,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
+    auto rank = dotLayout.getWarpsPerCTA().size();
     // FIXME: delete if branch for `DpasEncodingAttr` and provide more
     // general solution to make `getOrderForDotOperand` function compatible
     // with Intel layouts.
@@ -314,7 +306,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
       std::iota(order.rbegin(), order.rend(), 0);
       return order;
     }
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1069,7 +1061,17 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+  // general solution to make `getOrderForDotOperand` function compatible
+  // with Intel layouts.
+  // More details:
+  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+    return ::getOrder(*this);
+  } else {
+    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                                 /*kMajor*/ true);
+  }
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2055,6 +2057,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 9 additions & 9 deletions
@@ -906,17 +906,17 @@ std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
-  }
-  if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+    // FIXME [Dot LL]
+    // Do this unconditionally
+    auto largeKWidth = getKWidth() == 8;
+    if (mma.isAmpere() && largeKWidth) {
+      return ampereDotToLinearLayout(shape, *this);
+    }
+  } else if (auto dpasLayout =
+                 llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }
-
-  // TODO Activate in a follow-up PR
-  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
-  //   if (mma.isAmpere()) {
-  //     return ampereDotToLinearLayout(shape, *this);
-  //   }
-  //}
   return std::nullopt;
 }
922922

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 18 additions & 13 deletions
@@ -17,8 +17,9 @@ LogicalResult UpcastMXFPOp::verify() {
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();

-  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
-    return emitOpError("element type of the first operand must be bf16");
+  if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
+      xTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the first operand must be bf16 or i8");
   }

   if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
@@ -72,7 +73,7 @@ LogicalResult UpcastMXFPOp::verify() {
 }

 LogicalResult UpcastMXFPOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties opaqueProperties,
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   auto xTy = cast<RankedTensorType>(operands[0].getType());
@@ -82,21 +83,25 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(

   auto encoding = xTy.getEncoding();
   if (!encoding) {
-    return emitOptionalError(location, "expected an encoding");
+    return emitOptionalError(loc, "expected an encoding");
   }
   if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(location, "expected an mma layout encoding");
-  }
-  if (xShape.size() < 2) {
-    return emitOptionalError(location, "tensor rank must be at least 2");
+    return emitOptionalError(loc, "expected a dotOperand encoding");
   }

-  // For now we just return the input encoding. For fp4 we'll need to cast from
-  // tf32 to fp16 encoding and multiply the shape by two
-  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
-         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+  if (typeEncoded == F8F6F4Type::E2M1) {
+    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+    auto newVEncoding = DotOperandEncodingAttr::get(
+        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+        oldEncoding.getKWidth() * 2);
+    auto newShape = SmallVector<int64_t>(xShape);
+    newShape.back() *= 2;
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+  } else {
+    inferredReturnTypes.push_back(xTy);
+  }

-  inferredReturnTypes.push_back(xTy);
   return success();
 }

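For packed fp4 (E2M1) inputs, inferReturnTypes now doubles both the dot-operand kWidth and the innermost dimension, since each i8 element packs two fp4 values, and switches the result element type to bf16. A standalone sketch of that bookkeeping, with an assumed example shape:

// Standalone illustration (not the MLIR code) of the E2M1 result-type rule.
#include <cstdio>
#include <vector>

int main() {
  // Assumed example: an i8 tensor holding two fp4 values per byte.
  std::vector<long> xShape = {128, 32}; // [M, K/2] packed
  unsigned kWidth = 4;                  // dot-operand kWidth of the input

  std::vector<long> newShape = xShape;
  newShape.back() *= 2;                 // unpacking doubles the K dimension
  unsigned newKWidth = kWidth * 2;      // each thread owns twice as many elements

  std::printf("result: bf16 tensor [%ld, %ld], kWidth=%u\n",
              newShape[0], newShape[1], newKWidth);
  return 0;
}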

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 38 additions & 24 deletions
@@ -406,7 +406,7 @@ class ScaledBlockedToMMAv2
     auto ctx = dotOp.getContext();

     // Check that rhs scale is null
-    assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");
+    assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI");

     // operands
     auto a = dotOp.getLhs();
@@ -426,10 +426,11 @@ class ScaledBlockedToMMAv2
       }
     };

-    assert(aType == F8F6F4Type::E4M3 ||
-           aType == F8F6F4Type::E5M2 && "lhs just supports fp8");
+    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
+            aType == F8F6F4Type::E2M1) &&
+           "NYI: lhs supports fp4 or fp8");
     assert(bType == F8F6F4Type::E4M3 ||
-           bType == F8F6F4Type::E5M2 && "rhs just supports fp8");
+           bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");

     // TODO run accelerate matmul on A and B first to choose their layouts
     // Set return type
@@ -440,6 +441,7 @@ class ScaledBlockedToMMAv2
     auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
                                              rewriter.getBF16Type(), numWarps);
     auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    // TODO Use warpsPerTileV2
     SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
     auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
                                              /*versionMinor=*/0, warpsPerCTA,
@@ -452,27 +454,39 @@ class ScaledBlockedToMMAv2
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

-    auto toMMABf16 = [&newRetType, &rewriter, &ctx,
-                      &enumToType](TypedValue<RankedTensorType> v, int idx,
-                                   F8F6F4Type type) {
-      // MMAv2 Layout
+    auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
+                         TypedValue<RankedTensorType> v, int idx,
+                         F8F6F4Type type) -> TypedValue<RankedTensorType> {
       auto vType = v.getType();
-      auto newVEncoding = DotOperandEncodingAttr::get(
-          ctx, idx, newRetType.getEncoding(), enumToType((type)));
-      auto newVType = RankedTensorType::get(
-          v.getType().getShape(), v.getType().getElementType(), newVEncoding);
-      v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
-
-      // Bitcast
-      auto vTypeFp8 = RankedTensorType::get(
-          vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
-      v = cast<TypedValue<RankedTensorType>>(
-          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-
-      // Convert to bf16
-      auto vTypeBf16 = RankedTensorType::get(
-          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      if (type == F8F6F4Type::E2M1) {
+        // A bit too dynamically typed...
+        // perhaps return ints in both cases?
+
+        auto retEnc = dyn_cast<NvidiaMmaEncodingAttr>(newRetType.getEncoding());
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/4);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+      } else {
+        assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+        // Bitcast
+        auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
+                                              enumToType(type), newVEncoding);
+        v = cast<TypedValue<RankedTensorType>>(
+            rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+        // Convert to bf16
+        auto vTypeBf16 = RankedTensorType::get(
+            vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+        return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      }
     };
     a = toMMABf16(a, 0, aType);
     b = toMMABf16(b, 1, bType);
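The reworked toMMABf16 lambda now branches on the operand type: fp4 (E2M1) operands only receive a layout conversion to a kWidth=4 dot-operand encoding and keep their packed element type (UpcastMXFPOp unpacks them later), while fp8 operands use kWidth=8 and additionally go through a bitcast to the fp8 type and an FpToFp upcast to bf16. A rough standalone outline of the two paths (plain C++, not the MLIR rewriter API):

// Plain-C++ outline of the two operand-preparation paths implemented above.
#include <cstdio>

enum class F8F6F4Type { E2M1, E4M3, E5M2 };

void prepareOperand(F8F6F4Type type) {
  if (type == F8F6F4Type::E2M1) {
    // fp4: element type stays the packed container; only the encoding changes.
    std::printf("convert_layout -> dot operand, kWidth=4\n");
  } else {
    // fp8: re-encode, reinterpret the raw bytes as the fp8 type, upcast to bf16.
    std::printf("convert_layout -> dot operand, kWidth=8\n");
    std::printf("bitcast -> fp8 (E4M3 or E5M2)\n");
    std::printf("fp_to_fp -> bf16\n");
  }
}

int main() {
  prepareOperand(F8F6F4Type::E2M1);
  prepareOperand(F8F6F4Type::E4M3);
  return 0;
}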

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 7 additions & 0 deletions
@@ -44,6 +44,13 @@ class TritonGPUReduceDataDuplicationPass
       return;
     if (!cvtNeedsSharedMemory(srcType, dstType))
       return;
+    // FIXME [Dot LL]
+    // We support this one via LLs, as the LocalLoad path is buggy
+    bool largeKWidth =
+        dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+    if (largeKWidth) {
+      return;
+    }
     auto srcOrder = triton::gpu::getOrder(srcEncoding);
     auto rank = srcOrder.size();
     SmallVector<unsigned> sharedOrder;

python/test/unit/language/test_core.py

Lines changed: 1 addition & 2 deletions
@@ -3475,10 +3475,9 @@ def mxfp_to_bf16_kernel(
     tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32)

 def dot_scale_ref(x, scale, y, type_x, type_y):
-    e_bits, m_bits = {"e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
+    e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
     type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]

-    # Need to implement fp4 -> fp8 cast to support 1 byte types
     comp_dtype = torch.bfloat16

     x = x.contiguous()
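The new "e2m1" entry maps the fp4 format to its (exponent bits, mantissa bits) pair: 1 sign bit, 2 exponent bits with bias 1, and 1 mantissa bit, giving the value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} plus negatives. A small standalone decoder sketch for reference (illustrative; not part of the test suite):

// Decode a 4-bit e2m1 (fp4) value to float, for reference.
#include <cstdio>

float decodeE2M1(unsigned nibble) {
  unsigned sign = (nibble >> 3) & 1;
  unsigned exp = (nibble >> 1) & 3;
  unsigned man = nibble & 1;
  float mag = (exp == 0) ? 0.5f * man // subnormal: 0 or 0.5
                         : (1.0f + 0.5f * man) * float(1u << (exp - 1));
  return sign ? -mag : mag;
}

int main() {
  for (unsigned v = 0; v < 8; ++v)
    std::printf("0x%x -> %g\n", v, decodeE2M1(v)); // 0, 0.5, 1, 1.5, 2, 3, 4, 6
  return 0;
}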
