Commit 1c2f9d0

Merge commit 'c172d539a2f412eaec7f508c81e0cf1f21e95ede'
2 parents: 0499edd + c172d53

78 files changed: +2374 −1911 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ cmake-build-*
 cuobjdump
 nvdisasm
 ptxas
+ptxas-blackwell

 # Third-party include
 third_party/nvidia/backend/include

cmake/nvidia-toolchain-version.json

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 {
+  "ptxas-blackwell": "12.9.86",
   "ptxas": "12.8.93",
   "cuobjdump": "12.8.55",
   "nvdisasm": "12.8.55",

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 28 additions & 86 deletions
@@ -57,6 +57,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
   // computation is eliminated.
   SmallVector<Value> maybeDeduplicate(SourceOp op,
                                       SmallVector<Value> resultVals) const {
+    auto ctx = op.getContext();
     if (!isMemoryEffectFree(op))
       // the op has side effects: can't dedup
       return resultVals;
@@ -65,104 +66,45 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       // there must be exactly 1 result
       return resultVals;
     Value result = results[0];
-    Type type = result.getType();
-    if (!type)
-      return resultVals;
-    RankedTensorType rtType = dyn_cast<RankedTensorType>(type);
+    RankedTensorType rtType = dyn_cast<RankedTensorType>(result.getType());
     if (!rtType)
       // the result must be a tensor
       return resultVals;
-    Attribute encoding = rtType.getEncoding();
-    if (!encoding)
-      // encoding not available
-      return resultVals;
-    Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
-        isa<AMDWmmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
-      // now. We saw mismatches for some flash-attention and dot tests on AMD
-      // backend. Note that this logic works for sliced layout whose parent is
-      // mfma layout. Therefore, this is not combined with the following check.
-      return resultVals;
-    while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
-      baseEncoding = sliced.getParent();
-    if (isa<LinearEncodingAttr, DotOperandEncodingAttr>(baseEncoding)) {
-      // TODO: this logic seems incorrect for mma layout. Skip for now.
-      // The following test crashes and some other miscompile:
-      // test_core::test_fp8_dot_acc
-      return resultVals;
-    }

-    SmallVector<unsigned> elemsPerThread = getElemsPerThread(rtType);
-    int rank = elemsPerThread.size();
-    if (product<unsigned>(elemsPerThread) != resultVals.size())
-      return resultVals;
+    // Bail out if we don't have the constancy analysis
     AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
     if (!axisInfo)
-      // axis info (e.g., constancy) not available
-      return resultVals;
-    SmallVector<unsigned> contigPerThread = getContigPerThread(rtType);
-    if (rank != contigPerThread.size())
       return resultVals;
-
     SmallVector<int64_t> constancy = axisInfo->getConstancy();
-    if (rank != constancy.size())
-      return resultVals;
-    bool hasConstancy = false;
-    for (int i = 0; i < rank; ++i) {
-      if (constancy[i] > contigPerThread[i]) {
-        if (constancy[i] % contigPerThread[i] != 0)
-          // constancy is not evenly covered by contigPerThread
-          return resultVals;
-        // can't move the values across different
-        // "contigPerThread"-sized blocks
-        constancy[i] = contigPerThread[i];
-      }
-      if (elemsPerThread[i] < 1 || constancy[i] < 1)
-        return resultVals;
-      if (!(elemsPerThread[i] % constancy[i] == 0 ||
-            constancy[i] % elemsPerThread[i] == 0))
-        // either the constancy along each dimension must fit
-        // into the elemsPerThread or the other way around
-        return resultVals;
-      if (constancy[i] > 1)
-        hasConstancy = true;
-    }
-    if (!hasConstancy)
-      // nothing to deduplicate
-      return resultVals;

-    if (rank > 1) {
-      // reorder the shape and constancy vectors by the axis order:
-      // from the fastest-changing to the smallest-changing axis
-      SmallVector<unsigned> order = getOrder(rtType);
-      if (rank != order.size())
-        return resultVals;
-      elemsPerThread = applyPermutation(elemsPerThread, order);
-      constancy = applyPermutation(constancy, order);
-    }
+    if (llvm::all_of(constancy, [](int64_t c) { return c == 1; }))
+      return resultVals;

-    SmallVector<unsigned> strides(rank, 1);
-    for (int i = 1; i < rank; ++i) {
-      strides[i] = strides[i - 1] * elemsPerThread[i - 1];
-    }
-    SmallVector<Value> dedupResultVals;
-    dedupResultVals.reserve(resultVals.size());
-    for (int i = 0; i < resultVals.size(); ++i) {
-      // each coordinate of the orig_idx is "coarsened" using the
-      // constancy along this dimension: the resulting dedup_idx
-      // points to the reused value in the original resultsVal
-      int orig_idx = i;
-      int dedup_idx = 0;
-      for (int j = 0; j < rank; ++j) {
-        int coord_j = orig_idx % elemsPerThread[j];
-        dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
-        orig_idx /= elemsPerThread[j];
+    // We zero out the bases that are constant
+    auto kReg = StringAttr::get(ctx, "register");
+    auto ll = toLinearLayout(rtType);
+    auto dims = to_vector(ll.getOutDimNames());
+    auto llReg = ll.sublayout({kReg}, dims);
+    auto inv = ll.pseudoinvert();
+    auto invReg = inv.sublayout(dims, {kReg});
+    auto bases_inv = invReg.getBases();
+    for (auto [c, d] : llvm::zip(constancy, dims)) {
+      assert(llvm::isPowerOf2_32(c));
+      for (int i = 0; i < llvm::Log2_32(c); i++) {
+        bases_inv[d][i] = {0};
       }
-      dedupResultVals.push_back(resultVals[dedup_idx]);
     }
-
-    return dedupResultVals;
+    auto invBroadcast =
+        LinearLayout(bases_inv, invReg.getOutDims(), /*isSurjective=*/false);
+    auto cvt = llReg.compose(invBroadcast);
+
+    // Deduplicate the result values
+    SmallVector<Value> outVals(resultVals.size());
+    for (int i = 0; i < outVals.size(); i++) {
+      auto srcIdx = cvt.apply({{kReg, i}}).begin()->second;
+      outVals[i] = resultVals[srcIdx];
+    }
+    return outVals;
   }
   LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
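
The rewrite above derives the reuse mapping algebraically: the register-to-tensor layout is composed with a pseudo-inverse whose low bases along constant dimensions are zeroed, so each register index lands on the first register holding the same value. Since constancy is a power of two, this amounts to clearing the low log2(c) bits of each coordinate. For intuition, a standalone sketch of the equivalent index map, written the way the removed loop computed it; dedupMap and its parameters are illustrative names, not part of this commit:

#include <cstddef>
#include <vector>

// Hypothetical illustration: map each flat per-thread element index to the
// index of its canonical copy by "coarsening" every coordinate down to the
// start of its constancy block, exactly as the deleted loop did.
std::vector<size_t> dedupMap(const std::vector<size_t> &elemsPerThread,
                             const std::vector<size_t> &constancy) {
  size_t total = 1;
  for (size_t e : elemsPerThread)
    total *= e;
  std::vector<size_t> map(total);
  for (size_t i = 0; i < total; ++i) {
    size_t origIdx = i, dedupIdx = 0, stride = 1;
    for (size_t j = 0; j < elemsPerThread.size(); ++j) {
      size_t coord = origIdx % elemsPerThread[j];
      dedupIdx += (coord / constancy[j]) * constancy[j] * stride;
      origIdx /= elemsPerThread[j];
      stride *= elemsPerThread[j];
    }
    map[i] = dedupIdx; // resultVals[map[i]] can stand in for resultVals[i]
  }
  return map;
}
// Example: elemsPerThread = {4}, constancy = {2} yields {0, 0, 2, 2}.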

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 3 additions & 1 deletion
@@ -129,7 +129,9 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   [
     I32EnumAttrCase<"TF32", 0, "tf32">,
     I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
-    I32EnumAttrCase<"IEEE", 2, "ieee">
+    I32EnumAttrCase<"IEEE", 2, "ieee">,
+    I32EnumAttrCase<"BF16x3", 3, "bf16x3">,
+    I32EnumAttrCase<"BF16x6", 4, "bf16x6">
   ]>{
   let cppNamespace = "::mlir::triton";
 }

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 1 deletion
@@ -664,9 +664,11 @@ def TT_DotOp : TT_Op<"dot", [Pure,

   let description = [{
     $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC
-    when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
+    when the inputs are f32. It can be one of: tf32, tf32x3, ieee, bf16x3, bf16x6.
     tf32: use TC with tf32 ops.
     tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
+    bf16x3: implement the 3xBF16 trick. For more info see the pass in F32DotTC.cpp
+    bf16x6: implement the 6xBF16 trick. For more info see the pass in F32DotTC.cpp
     ieee: don't use TC, implement dot in software.
     If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
   }];
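
The bf16x3/bf16x6 tricks come from the paper cited in the F32DotTC pass description (https://arxiv.org/abs/1904.06376): an f32 value is split into bf16 parts whose sum reconstructs it, and a few bf16 tensor-core products recover near-f32 accuracy. A scalar sketch of one standard 3-way-split scheme follows; the helper names are illustrative, truncation stands in for hardware rounding, and the exact terms F32DotTC.cpp emits may differ:

#include <cstdint>
#include <cstring>

// Truncate an f32 toward zero to bf16 precision by dropping the low 16 bits
// (bf16 keeps the f32 exponent and the top 7 mantissa bits; real hardware
// rounds to nearest-even, truncation keeps the sketch short).
static float toBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

// Split x into three bf16-representable parts with x ≈ hi + mid + lo.
struct Bf16Parts { float hi, mid, lo; };
static Bf16Parts split(float x) {
  float hi = toBF16(x);
  float mid = toBF16(x - hi);
  float lo = toBF16(x - hi - mid);
  return {hi, mid, lo};
}

// bf16x6: six bf16 products per f32 multiply, accumulated small-to-large.
// bf16x3 keeps only the three leading terms (hi*hi, hi*mid, mid*hi).
static float mulBf16x6(float a, float b) {
  Bf16Parts s = split(a), t = split(b);
  return ((s.lo * t.hi + s.hi * t.lo) + s.mid * t.mid) +
         (s.mid * t.hi + s.hi * t.mid) + s.hi * t.hi;
}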

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 16 additions & 15 deletions
@@ -117,6 +117,19 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
                      int32_t elemBitWidth, unsigned instBitWidth,
                      unsigned numLanesInShuffleGroup);

+LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
+                                           int numWarps);
+
+std::optional<LinearLayout>
+getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
+                             int numWarps);
+
+// Return a layout valid for TMemLoad op for a tmem layout of block MxN that
+// distributes the data along M for the warp groups. This doesn't affect the
+// TMem layout; it just returns a distributed layout compatible for tmem_load.
+LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
+                                         int numWarps);
+
 // Create LinearLayout for scale in scaled mfma.
 LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<int64_t> dotOperandShape,
@@ -129,12 +142,10 @@ LinearLayout chooseScaledWmmaScaleLayout(
     const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
     ArrayRef<int64_t> dotOperandShape);

-LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
-                                          ArrayRef<int64_t> dotOperandShape,
-                                          ArrayRef<unsigned> tilesPerWarp,
+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
+                                          ArrayRef<int64_t> shape, int opIdx,
                                           ArrayRef<unsigned> warpsPerCTA,
-                                          unsigned instrM, unsigned instrN,
-                                          CTALayoutAttr ctaLayoutAttr);
+                                          CTALayoutAttr ctaLayout);

 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
@@ -151,15 +162,5 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle);

-// Make a LinearLayout that maps a block-id to an N-dimensional index.
-//
-// The tensor is split up into CTAsPerCGA pieces, which are distributed among
-// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups).
-//
-// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
-// for an explanation of why this is called makeCgaLayout when it accepts a
-// CTALayoutAttr.
-LinearLayout makeCgaLayout(CTALayoutAttr layout);
-
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 12 additions & 5 deletions
@@ -177,15 +177,22 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir:
 }

 def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
-  let summary = "3xTF32 trick";
+  let summary = "Emulate dot-product tensor core precision using TF32s or BF16s";

   let description = [{
-    Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
-    to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385
+    Generic pass to emulate/decompose f32 `DotOp` instructions.
+    * Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
+      to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385.
+    * Decompose fp32 `DotOp` instructions into BF16 operations.
+      See https://arxiv.org/abs/1904.06376
   }];

-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
-                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let options = [
+    Option<"emuTF32", "emu-tf32",
+           "bool", /*default*/"false",
+           "whether to handle InputPrecision TF32xN for Nvidia GPUs">
+  ];
 }

 def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
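
For the TF32xN path the decomposition is analogous to the BF16 one: the 3xTF32 trick referenced by the cutlass discussion splits each f32 operand into a tf32-representable "big" part plus a residual "small" part (the 4 pointwise ops) and accumulates three tf32 products (the 3 DotOps). A hedged scalar sketch, assuming truncation rather than hardware rounding; the names are illustrative:

#include <cstdint>
#include <cstring>

// Zero the low 13 mantissa bits of an f32, leaving the 10 explicit mantissa
// bits a tf32 tensor-core input keeps (truncation instead of rounding for
// brevity).
static float toTF32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= ~((1u << 13) - 1);
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

// 3xTF32: one f32 product becomes three tf32 products plus the pointwise
// splits, accumulated from the smallest terms up. The dropped lo*lo term
// sits below f32 round-off for normalized inputs.
static float mul3xTF32(float a, float b) {
  float aHi = toTF32(a), aLo = toTF32(a - aHi);
  float bHi = toTF32(b), bLo = toTF32(b - bHi);
  return (aLo * bHi + aHi * bLo) + aHi * bHi;
}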

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 7 additions & 52 deletions
@@ -29,7 +29,6 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
-#include "llvm/Support/ErrorHandling.h"

 // TritonNvidiaGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -62,68 +61,24 @@ struct TMemAllocation {
   int numCols;
 };

-// Used to describe the layout of the TMEM load/store instructions
-enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 };
-
-inline int getElementsPerThread(TMemAccessAtom atom) {
-  switch (atom) {
-  case TMemAccessAtom::I32x32b:
-  case TMemAccessAtom::I16x64b:
-  case TMemAccessAtom::I16x32bx2:
-    return 1;
-  case TMemAccessAtom::I16x128b:
-    return 2;
-  case TMemAccessAtom::I16x256b:
-    return 4;
-  }
-  llvm_unreachable("Unknown TMemAccessAtom");
-}
-
-inline const char *getOpShape(TMemAccessAtom atom) {
-  switch (atom) {
-  case TMemAccessAtom::I32x32b:
-    return "32x32b";
-  case TMemAccessAtom::I16x64b:
-    return "16x64b";
-  case TMemAccessAtom::I16x128b:
-    return "16x128b";
-  case TMemAccessAtom::I16x256b:
-    return "16x256b";
-  case TMemAccessAtom::I16x32bx2:
-    return "16x32bx2";
-  }
-  llvm_unreachable("Unknown TMemAccessAtom");
-}
-
-LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom,
-                           bool unpacked);
-
 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);

-SmallVector<gpu::DistributedEncodingTrait>
-getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps,
-                         ArrayRef<int64_t> ctaSplit = {1, 1});
-
-std::optional<gpu::DistributedEncodingTrait>
+gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
+                                                      RankedTensorType oldType,
+                                                      unsigned numWarps);
+gpu::DistributedEncodingTrait
 getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
                             gpu::MemDescType memType, int numWarps);
-
 SmallVector<gpu::DistributedEncodingTrait>
 getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
                          gpu::MemDescType memType);

 bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
-
-gpu::DistributedEncodingTrait
-getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps,
-                            gpu::CTALayoutAttr ctaLayout);
-
-std::optional<LinearLayout>
-getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom,
-                                unsigned numWarps,
-                                gpu::CTALayoutAttr ctaLayout);
+bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
+                                            gpu::MemDescType memType,
+                                            int numWarps);

 } // namespace mlir::triton::nvidia_gpu
include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h

Lines changed: 0 additions & 37 deletions
This file was deleted.

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 12 deletions
@@ -558,18 +558,6 @@ class LinearLayout {
     return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}});
   }

-  [[nodiscard]] LinearLayout renameInDim(StringAttr oldDim,
-                                         StringAttr newDim) const {
-    auto bases = getBases();
-    auto it = bases.find(oldDim);
-    assert(it != bases.end());
-    auto value = std::move(it->second);
-    bases.erase(it);
-    bases.insert({newDim, std::move(value)});
-    return LinearLayout(bases, getOutDims(),
-                        /*requireSurjective=*/isSurjective());
-  }
-
   // Concatenates two layouts by their in (resp. out) dimensions. The layouts
   // must have the same output (resp. input) dimensions and sizes and different
   // input (resp. output) dimensions. The input dimensions of this layout are
