1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -732,6 +732,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
`lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict
`:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d)
}];
let hasVerifier = 1;
}

//
23 changes: 10 additions & 13 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -117,19 +117,6 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
int32_t elemBitWidth, unsigned instBitWidth,
unsigned numLanesInShuffleGroup);

LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
int numWarps);

std::optional<LinearLayout>
getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
int numWarps);

// Return a layout valid for TMemLoad op for a tmem layout of block MxN that
// distribute the data long M for the warp groups. This doesn't affect the TMem
// layout it just returns a distributed layout compatible for tmem_load.
LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
int numWarps);

// Create LinearLayout for scale in scaled mfma.
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<int64_t> dotOperandShape,
@@ -161,5 +148,15 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
bool disableSwizzle);

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups).
//
// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
// for an explanation of why this is called makeCgaLayout when it accepts a
// CTALayoutAttr.
LinearLayout makeCgaLayout(CTALayoutAttr layout);

} // namespace mlir::triton::gpu
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
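A rough sketch of how makeCgaLayout might be exercised (the ctx value, the 2x2 CGA shape, and the attribute values are illustrative assumptions, not taken from this PR):

// Describe a 2x2 CGA in which the tensor is split across all four CTAs
// (values chosen purely for illustration).
auto ctaLayout = triton::gpu::CTALayoutAttr::get(
    ctx, /*CTAsPerCGA=*/{2, 2}, /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0});
// The returned layout maps a block id to the N-dimensional index of the piece
// owned by that CTA; here the single input dimension has size 4.
LinearLayout blockToIndex = triton::gpu::makeCgaLayout(ctaLayout);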
59 changes: 52 additions & 7 deletions include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -29,6 +29,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "llvm/Support/ErrorHandling.h"

// TritonNvidiaGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -61,24 +62,68 @@ struct TMemAllocation {
int numCols;
};

// Used to describe the layout of the TMEM load/store instructions
enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 };

inline int getElementsPerThread(TMemAccessAtom atom) {
switch (atom) {
case TMemAccessAtom::I32x32b:
case TMemAccessAtom::I16x64b:
case TMemAccessAtom::I16x32bx2:
return 1;
case TMemAccessAtom::I16x128b:
return 2;
case TMemAccessAtom::I16x256b:
return 4;
}
llvm_unreachable("Unknown TMemAccessAtom");
}

inline const char *getOpShape(TMemAccessAtom atom) {
switch (atom) {
case TMemAccessAtom::I32x32b:
return "32x32b";
case TMemAccessAtom::I16x64b:
return "16x64b";
case TMemAccessAtom::I16x128b:
return "16x128b";
case TMemAccessAtom::I16x256b:
return "16x256b";
case TMemAccessAtom::I16x32bx2:
return "16x32bx2";
}
llvm_unreachable("Unknown TMemAccessAtom");
}

LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom,
bool unpacked);

TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);

gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
RankedTensorType oltType,
unsigned numWarps);
gpu::DistributedEncodingTrait
SmallVector<gpu::DistributedEncodingTrait>
getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps,
ArrayRef<int64_t> ctaSplit = {1, 1});

std::optional<gpu::DistributedEncodingTrait>
getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
gpu::MemDescType memType, int numWarps);

SmallVector<gpu::DistributedEncodingTrait>
getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
gpu::MemDescType memType);

bool isDistributedLayoutTMemCompatible(Operation *op,
RankedTensorType tensorType,
gpu::MemDescType memType);
bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
gpu::MemDescType memType,
int numWarps);

gpu::DistributedEncodingTrait
getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps,
gpu::CTALayoutAttr ctaLayout);

std::optional<LinearLayout>
getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom,
unsigned numWarps,
gpu::CTALayoutAttr ctaLayout);

} // namespace mlir::triton::nvidia_gpu

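A small illustration of the two switch helpers above (the choice of atom is arbitrary):

// For a 16x256b TMEM message, each thread contributes four 32-bit registers.
TMemAccessAtom atom = TMemAccessAtom::I16x256b;
int regsPerThread = getElementsPerThread(atom); // 4
const char *shape = getOpShape(atom);           // "16x256b"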
37 changes: 37 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h
@@ -0,0 +1,37 @@
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

#include <cstdint>
#include <functional>
#include <optional>

namespace mlir::triton::nvidia_gpu {

// Get the maximum number of registers per thread based on the context. This is
// by default 256, but it can be overridden by `ttg.maxnreg` set on the module
// or a contextual register limit set by the compiler on partitions.
int getContextualMaxNReg(Operation *op);
struct TMemLdStEncodingInfo {
TMemAccessAtom atom;
LinearLayout reps;
ColumnAction perm;
int numRegsPerMessage;
std::optional<uint32_t> secondHalfOffset;
std::optional<ColumnAction> broadcast = std::nullopt;
bool unpacked = false;
unsigned vec = 1;
bool padding = false;
};

FailureOr<TMemLdStEncodingInfo>
computeTMemLdStEncodingInfo(RankedTensorType regTy, gpu::MemDescType memTy,
int maxnreg,
std::function<InFlightDiagnostic()> emitError = {});

} // namespace mlir::triton::nvidia_gpu

#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
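A hedged sketch of how the new utility header is meant to be consumed; op, regTy, and memTy are placeholders standing in for a tmem load/store op and its register/memory types:

// Register budget: 256 by default, or whatever ttg.maxnreg (or a partition
// limit set by the compiler) dictates for this op's context.
int maxnreg = getContextualMaxNReg(op);
FailureOr<TMemLdStEncodingInfo> info = computeTMemLdStEncodingInfo(
    regTy, memTy, maxnreg, [&]() { return op->emitError(); });
if (failed(info))
  return failure();
// info->atom, info->reps, and info->numRegsPerMessage then drive the actual
// lowering of the load/store.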
19 changes: 19 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -558,6 +558,25 @@ class LinearLayout {
return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}});
}

// Resizes the dimension to one that is smaller than or equal to the given size.
// These operations are similar to `sublayout`, but operate at the level of a
// single dimension.
[[nodiscard]] LinearLayout resizeInDim(StringAttr inDim,
int32_t newSize) const;
[[nodiscard]] LinearLayout resizeOutDim(StringAttr outDim,
int32_t newSize) const;

[[nodiscard]] LinearLayout renameInDim(StringAttr oldDim,
StringAttr newDim) const {
auto bases = getBases();
auto it = bases.find(oldDim);
assert(it != bases.end());
auto value = std::move(it->second);
bases.erase(it);
bases.insert({newDim, std::move(value)});
return LinearLayout(bases, getOutDims(),
/*requireSurjective=*/isSurjective());
}

// Concatenates two layouts by their in (resp. out) dimensions. The layouts
// must have the same output (resp. input) dimensions and sizes and different
// input (resp. output) dimensions. The input dimensions of this layout are
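A minimal sketch of the resize/rename helpers added above; the dimension names, the size 16, and ctx are assumptions made only for illustration:

auto S = [&](StringRef name) { return StringAttr::get(ctx, name); };
LinearLayout layout = LinearLayout::identity1D(16, S("register"), S("offset"));
// Keep only the bases indexing the first 8 entries of the "register" dim.
LinearLayout smaller = layout.resizeInDim(S("register"), 8);
// Same bases and output dims, but the input dimension is now called "lane".
LinearLayout renamed = layout.renameInDim(S("register"), S("lane"));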
12 changes: 3 additions & 9 deletions lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp
@@ -21,16 +21,10 @@ namespace ttng = triton::nvidia_gpu;
RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
RankedTensorType type, MemDescType memdesc,
unsigned numWarps) {
Attribute encoding;
type = cast<RankedTensorType>(tc->convertType(type));
if (isa<ttng::TensorMemoryScalesEncodingAttr>(memdesc.getEncoding())) {
encoding = LinearEncodingAttr::get(
type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps));
} else {
auto tmemEnc = cast<ttng::TensorMemoryEncodingAttr>(memdesc.getEncoding());
encoding = ttng::getTmemCompatibleLayout(
tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps);
}
auto ctaLayout = getCTALayout(type.getEncoding());
auto encoding =
ttng::getDefaultLayoutForTmemLdSt(memdesc, numWarps, ctaLayout);
return type.cloneWithEncoding(encoding);
}

38 changes: 38 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -356,6 +356,44 @@ bool DotScaledOp::verifyOutputDims() {
return true;
}

LogicalResult DotScaledOp::verify() {
auto aShape = this->getA().getType().getShape();
int64_t rank = aShape.size();

auto k = aShape[rank - 1];
if (this->getAElemType() == ScaleDotElemType::E2M1) {
if (this->getLhsKPack())
k *= 2;
}
auto cShape = this->getC().getType().getShape();
int64_t mDim = cShape[cShape.size() - 2];
int64_t nDim = cShape[cShape.size() - 1];

if (getAScale()) {
auto aScaleShape = getAScale().getType().getShape();
if (aScaleShape[rank - 2] != mDim)
return this->emitError(
"scales M dimension must match the operand M dimension");
int scale_factor =
isa<Float8E4M3FNType>(getAScale().getType().getElementType()) ? 16 : 32;
if (aScaleShape[rank - 1] != k / scale_factor)
return this->emitError("scales K dimension must match the operand K "
"divided by the scale factor");
}
if (getBScale()) {
auto bScaleShape = getBScale().getType().getShape();
if (bScaleShape[rank - 2] != nDim)
return this->emitError(
"scales N dimension must match the operand N dimension");
int scale_factor =
isa<Float8E4M3FNType>(getBScale().getType().getElementType()) ? 16 : 32;
if (bScaleShape[rank - 1] != k / scale_factor)
return this->emitError("scales K dimension must match the operand K "
"divided by the scale factor");
}
return success();
}

//-- MakeRangeOp --
OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
// make_range(start, start + 1) -> constant(start)
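A worked example of the new verifier, with shapes invented for illustration: if a has shape 128x512 with a_elem_type = e2m1 and lhs_k_pack = true, the effective K is 1024. An a_scale tensor with f8E4M3FN elements must then have shape 128x64 (1024 / 16), while a non-E4M3 scale type such as e8m0 requires 128x32 (1024 / 32); the leading scale dimension must also match the M dimension of c. Any other scale shape triggers the corresponding emitError above.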