Commit 0111436

Merge commit 'c186592a17299439900d712e85556e8578345821'
2 parents f3a0aec + c186592 commit 0111436

66 files changed: +1914 −1559 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 0 deletions
@@ -732,6 +732,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
     `lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict
     `:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d)
   }];
+  let hasVerifier = 1;
 }
 
 //
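
In ODS, setting `hasVerifier = 1` makes mlir-tblgen declare a `verify()` member on the generated op class; the body for `tt.dot_scaled` is supplied in lib/Dialect/Triton/IR/Ops.cpp further down in this commit. As a hedged sketch (not part of the commit), the new check runs whenever the op is verified, for example through `mlir::verify`:

// Illustrative only: verifying a module now also runs DotScaledOp::verify()
// for every tt.dot_scaled op it contains.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"

static bool moduleIsValid(mlir::ModuleOp module) {
  // Fails if any op-level verifier, including the new one, emits an error
  // (e.g. scale shapes that do not match the operand shapes).
  return mlir::succeeded(mlir::verify(module.getOperation()));
}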

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 10 additions & 13 deletions
@@ -117,19 +117,6 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
                      int32_t elemBitWidth, unsigned instBitWidth,
                      unsigned numLanesInShuffleGroup);
 
-LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
-                                           int numWarps);
-
-std::optional<LinearLayout>
-getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
-                             int numWarps);
-
-// Return a layout valid for TMemLoad op for a tmem layout of block MxN that
-// distribute the data long M for the warp groups. This doesn't affect the TMem
-// layout it just returns a distributed layout compatible for tmem_load.
-LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
-                                         int numWarps);
-
 // Create LinearLayout for scale in scaled mfma.
 LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<int64_t> dotOperandShape,
@@ -161,5 +148,15 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle);
 
+// Make a LinearLayout that maps a block-id to an N-dimensional index.
+//
+// The tensor is split up into CTAsPerCGA pieces, which are distributed among
+// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups).
+//
+// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
+// for an explanation of why this is called makeCgaLayout when it accepts a
+// CTALayoutAttr.
+LinearLayout makeCgaLayout(CTALayoutAttr layout);
+
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
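
For orientation, a hedged usage sketch (not code from this commit) of the new helper: build a CTALayoutAttr describing a 2x2 CGA and ask makeCgaLayout for the block-id to tensor-index mapping. The CTALayoutAttr::get parameter order shown here is assumed from the TritonGPU dialect definitions.

// Illustrative sketch of calling makeCgaLayout; shapes are made up.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

namespace ttg = mlir::triton::gpu;

mlir::triton::LinearLayout blockToTensorIndex(mlir::MLIRContext *ctx) {
  // Four CTAs per CGA, split 2x2 across the two tensor dimensions, with
  // dimension 1 fastest-varying.
  auto ctaLayout = ttg::CTALayoutAttr::get(ctx, /*CTAsPerCGA=*/{2, 2},
                                           /*CTASplitNum=*/{2, 2},
                                           /*CTAOrder=*/{1, 0});
  // Maps a block-id to the N-dimensional index of the tensor piece it owns.
  return ttg::makeCgaLayout(ctaLayout);
}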

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 52 additions & 7 deletions
@@ -29,6 +29,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "llvm/Support/ErrorHandling.h"
 
 // TritonNvidiaGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -61,24 +62,68 @@ struct TMemAllocation {
   int numCols;
 };
 
+// Used to describe the layout of the TMEM load/store instructions
+enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 };
+
+inline int getElementsPerThread(TMemAccessAtom atom) {
+  switch (atom) {
+  case TMemAccessAtom::I32x32b:
+  case TMemAccessAtom::I16x64b:
+  case TMemAccessAtom::I16x32bx2:
+    return 1;
+  case TMemAccessAtom::I16x128b:
+    return 2;
+  case TMemAccessAtom::I16x256b:
+    return 4;
+  }
+  llvm_unreachable("Unknown TMemAccessAtom");
+}
+
+inline const char *getOpShape(TMemAccessAtom atom) {
+  switch (atom) {
+  case TMemAccessAtom::I32x32b:
+    return "32x32b";
+  case TMemAccessAtom::I16x64b:
+    return "16x64b";
+  case TMemAccessAtom::I16x128b:
+    return "16x128b";
+  case TMemAccessAtom::I16x256b:
+    return "16x256b";
+  case TMemAccessAtom::I16x32bx2:
+    return "16x32bx2";
+  }
+  llvm_unreachable("Unknown TMemAccessAtom");
+}
+
+LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom,
+                           bool unpacked);
+
 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
 
-gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
-                                                      RankedTensorType oltType,
-                                                      unsigned numWarps);
-gpu::DistributedEncodingTrait
+SmallVector<gpu::DistributedEncodingTrait>
+getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps,
+                         ArrayRef<int64_t> ctaSplit = {1, 1});
+
+std::optional<gpu::DistributedEncodingTrait>
 getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
                             gpu::MemDescType memType, int numWarps);
+
 SmallVector<gpu::DistributedEncodingTrait>
 getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
                          gpu::MemDescType memType);
 
 bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
-bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
-                                            gpu::MemDescType memType,
-                                            int numWarps);
+
+gpu::DistributedEncodingTrait
+getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps,
+                            gpu::CTALayoutAttr ctaLayout);
+
+std::optional<LinearLayout>
+getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom,
+                                unsigned numWarps,
+                                gpu::CTALayoutAttr ctaLayout);
 
 } // namespace mlir::triton::nvidia_gpu
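
The new TMemAccessAtom enum names the tcgen05 TMEM load/store message shapes, and the two inline helpers expose the per-thread element count and the shape mnemonic for each atom. A small hedged sketch of how a caller might query them (the printing is illustrative, not code from this commit):

// Illustrative only: inspect one access atom.
#include <cstdio>

#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace ttng = mlir::triton::nvidia_gpu;

void printAtomInfo() {
  ttng::TMemAccessAtom atom = ttng::TMemAccessAtom::I16x256b;
  // 4 elements (32-bit registers) per thread per message for 16x256b.
  int elems = ttng::getElementsPerThread(atom);
  std::printf("shape=%s elemsPerThread=%d\n", ttng::getOpShape(atom), elems);
}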

include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
+#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/LinearLayout.h"
+
+#include <cstdint>
+#include <functional>
+#include <optional>
+
+namespace mlir::triton::nvidia_gpu {
+
+// Get the maximum number of registers per thread based on the context. This is
+// by default 256, but it can be overridden by `ttg.maxnreg` set on the module
+// or a contextual register limit set by the compiler on partitions.
+int getContextualMaxNReg(Operation *op);
+struct TMemLdStEncodingInfo {
+  TMemAccessAtom atom;
+  LinearLayout reps;
+  ColumnAction perm;
+  int numRegsPerMessage;
+  std::optional<uint32_t> secondHalfOffset;
+  std::optional<ColumnAction> broadcast = std::nullopt;
+  bool unpacked = false;
+  unsigned vec = 1;
+  bool padding = false;
+};
+
+FailureOr<TMemLdStEncodingInfo>
+computeTMemLdStEncodingInfo(RankedTensorType regTy, gpu::MemDescType memTy,
+                            int maxnreg,
+                            std::function<InFlightDiagnostic()> emitError = {});
+
+} // namespace mlir::triton::nvidia_gpu
+
+#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
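
computeTMemLdStEncodingInfo is the new single entry point that picks a TMEM message atom and the register-to-TMEM mapping for a given register tensor type and memory descriptor, reporting failures through the optional emitError callback. A hedged sketch of how a lowering might call it; the surrounding function, op, and variable names are assumptions, not code from this commit:

// Hypothetical caller: decide how to lower one tmem load.
#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h"

namespace ttng = mlir::triton::nvidia_gpu;
namespace ttg = mlir::triton::gpu;

mlir::LogicalResult lowerOneTmemLoad(mlir::Operation *op,
                                     mlir::RankedTensorType regTy,
                                     ttg::MemDescType memTy) {
  // 256 unless ttg.maxnreg (or a partition-level limit) overrides it.
  int maxnreg = ttng::getContextualMaxNReg(op);
  mlir::FailureOr<ttng::TMemLdStEncodingInfo> info =
      ttng::computeTMemLdStEncodingInfo(regTy, memTy, maxnreg,
                                        [&]() { return op->emitError(); });
  if (mlir::failed(info))
    return mlir::failure();
  // info->atom selects the tcgen05 message shape; info->reps and
  // info->numRegsPerMessage describe how messages tile the tensor and how
  // many 32-bit registers each message moves.
  (void)info->numRegsPerMessage;
  return mlir::success();
}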

include/triton/Tools/LinearLayout.h

Lines changed: 19 additions & 0 deletions
@@ -558,6 +558,25 @@ class LinearLayout {
     return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}});
   }
 
+  // Resizes the dimension to one that is smaller than or equal to the given
+  // size. These operations are similar to `sublayout` but at a dimension level.
+  [[nodiscard]] LinearLayout resizeInDim(StringAttr inDim,
+                                         int32_t newSize) const;
+  [[nodiscard]] LinearLayout resizeOutDim(StringAttr outDim,
+                                          int32_t newSize) const;
+
+  [[nodiscard]] LinearLayout renameInDim(StringAttr oldDim,
+                                         StringAttr newDim) const {
+    auto bases = getBases();
+    auto it = bases.find(oldDim);
+    assert(it != bases.end());
+    auto value = std::move(it->second);
+    bases.erase(it);
+    bases.insert({newDim, std::move(value)});
+    return LinearLayout(bases, getOutDims(),
+                        /*requireSurjective=*/isSurjective());
+  }
+
   // Concatenates two layouts by their in (resp. out) dimensions. The layouts
   // must have the same output (resp. input) dimensions and sizes and different
   // input (resp. output) dimensions. The input dimensions of this layout are
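
A hedged sketch (not from this commit) of the new dimension utilities: renameInDim relabels an input dimension while keeping its bases, and resizeInDim then shrinks it to a smaller power-of-two size. identity1D is an existing LinearLayout constructor, used here purely for illustration.

// Illustrative only: rename "register" to "lane", then shrink it from 16 to 8.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Tools/LinearLayout.h"

using mlir::StringAttr;
using mlir::triton::LinearLayout;

LinearLayout renameAndShrink(mlir::MLIRContext *ctx) {
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kDim0 = StringAttr::get(ctx, "dim0");

  // 16-wide identity layout over the "register" input dimension.
  LinearLayout ll = LinearLayout::identity1D(16, kReg, kDim0);

  // Both helpers return new layouts; `ll` itself is left untouched.
  LinearLayout renamed = ll.renameInDim(kReg, kLane);
  return renamed.resizeInDim(kLane, /*newSize=*/8);
}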

lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp

Lines changed: 3 additions & 9 deletions
@@ -21,16 +21,10 @@ namespace ttng = triton::nvidia_gpu;
 RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
                                      RankedTensorType type, MemDescType memdesc,
                                      unsigned numWarps) {
-  Attribute encoding;
   type = cast<RankedTensorType>(tc->convertType(type));
-  if (isa<ttng::TensorMemoryScalesEncodingAttr>(memdesc.getEncoding())) {
-    encoding = LinearEncodingAttr::get(
-        type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps));
-  } else {
-    auto tmemEnc = cast<ttng::TensorMemoryEncodingAttr>(memdesc.getEncoding());
-    encoding = ttng::getTmemCompatibleLayout(
-        tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps);
-  }
+  auto ctaLayout = getCTALayout(type.getEncoding());
+  auto encoding =
+      ttng::getDefaultLayoutForTmemLdSt(memdesc, numWarps, ctaLayout);
   return type.cloneWithEncoding(encoding);
 }
 

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 38 additions & 0 deletions
@@ -356,6 +356,44 @@ bool DotScaledOp::verifyOutputDims() {
   return true;
 }
 
+LogicalResult DotScaledOp::verify() {
+  auto aShape = this->getA().getType().getShape();
+  int64_t rank = aShape.size();
+
+  auto k = aShape[rank - 1];
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (this->getLhsKPack())
+      k *= 2;
+  }
+  auto cShape = this->getC().getType().getShape();
+  int64_t mDim = cShape[cShape.size() - 2];
+  int64_t nDim = cShape[cShape.size() - 1];
+
+  if (getAScale()) {
+    auto aScaleShape = getAScale().getType().getShape();
+    if (aScaleShape[rank - 2] != mDim)
+      return this->emitError(
+          "scales M dimension must match the operand M dimension");
+    int scale_factor =
+        isa<Float8E4M3FNType>(getAScale().getType().getElementType()) ? 16 : 32;
+    if (aScaleShape[rank - 1] != k / scale_factor)
+      return this->emitError("scales K dimension must match the operand K "
+                             "divided by the scale factor");
+  }
+  if (getBScale()) {
+    auto bScaleShape = getBScale().getType().getShape();
+    if (bScaleShape[rank - 2] != nDim)
+      return this->emitError(
+          "scales N dimension must match the operand N dimension");
+    int scale_factor =
+        isa<Float8E4M3FNType>(getBScale().getType().getElementType()) ? 16 : 32;
+    if (bScaleShape[rank - 1] != k / scale_factor)
+      return this->emitError("scales K dimension must match the operand K "
+                             "divided by the scale factor");
+  }
+  return success();
+}
+
 //-- MakeRangeOp --
 OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
   // make_range(start, start + 1) -> constant(start)
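
The new DotScaledOp verifier encodes a simple shape rule: with E2M1 (FP4) operands packed along K, the logical K is twice the stored K, and E4M3 scales use a group size of 16 while other scale types use 32. A standalone worked example of the same arithmetic, with illustrative shapes that are not taken from the commit:

// Self-contained illustration of DotScaledOp::verify()'s shape arithmetic.
// Example: a is 128x128 E2M1 with K packed (two FP4 values per element), so
// logical K = 256; with E4M3 scales the scale group size is 16.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t aShape[] = {128, 128}; // [M, K_stored]
  const int64_t cShape[] = {128, 64};  // [M, N]
  const bool aIsE2M1 = true, lhsKPack = true, aScaleIsE4M3 = true;

  int64_t k = aShape[1];
  if (aIsE2M1 && lhsKPack)
    k *= 2; // 256 logical K elements

  const int scaleFactor = aScaleIsE4M3 ? 16 : 32;

  // The verifier requires a_scale to have shape [M, K / scaleFactor].
  const int64_t expectedAScale[] = {cShape[0], k / scaleFactor}; // {128, 16}
  assert(expectedAScale[0] == 128 && expectedAScale[1] == 16);
  return 0;
}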
