1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -732,6 +732,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
`lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict
`:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d)
}];
let hasVerifier = 1;
}

//
23 changes: 10 additions & 13 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -117,19 +117,6 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
int32_t elemBitWidth, unsigned instBitWidth,
unsigned numLanesInShuffleGroup);

LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
int numWarps);

std::optional<LinearLayout>
getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
int numWarps);

// Return a layout valid for TMemLoad op for a tmem layout of block MxN that
// distributes the data along M for the warp groups. This doesn't affect the
// TMem layout; it just returns a distributed layout compatible with tmem_load.
LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
int numWarps);

// Create LinearLayout for scale in scaled mfma.
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<int64_t> dotOperandShape,
@@ -161,5 +148,15 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
bool disableSwizzle);

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. group).
//
// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
// for an explanation of why this is called makeCgaLayout when it accepts a
// CTALayoutAttr.
LinearLayout makeCgaLayout(CTALayoutAttr layout);

} // namespace mlir::triton::gpu
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
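
As a usage illustration (not code from this PR), here is a minimal sketch of building a CTALayoutAttr and handing it to the new makeCgaLayout; the 2x2 CGA arrangement and the helper name buildExampleCgaLayout are assumptions chosen for the example:

// Hypothetical example: 4 CTAs in the CGA, arranged 2x2 over a rank-2 tensor.
LinearLayout buildExampleCgaLayout(MLIRContext *ctx) {
  auto ctaLayout = CTALayoutAttr::get(ctx,
                                      /*CTAsPerCGA=*/{2, 2},
                                      /*CTASplitNum=*/{2, 2},
                                      /*CTAOrder=*/{1, 0});
  // Maps a block id to the N-dimensional index of the piece that CTA owns.
  return makeCgaLayout(ctaLayout);
}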
59 changes: 52 additions & 7 deletions include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -29,6 +29,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "llvm/Support/ErrorHandling.h"

// TritonNvidiaGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -61,24 +62,68 @@ struct TMemAllocation {
int numCols;
};

// Used to describe the layout of the TMEM load/store instructions
enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 };

inline int getElementsPerThread(TMemAccessAtom atom) {
switch (atom) {
case TMemAccessAtom::I32x32b:
case TMemAccessAtom::I16x64b:
case TMemAccessAtom::I16x32bx2:
return 1;
case TMemAccessAtom::I16x128b:
return 2;
case TMemAccessAtom::I16x256b:
return 4;
}
llvm_unreachable("Unknown TMemAccessAtom");
}

inline const char *getOpShape(TMemAccessAtom atom) {
switch (atom) {
case TMemAccessAtom::I32x32b:
return "32x32b";
case TMemAccessAtom::I16x64b:
return "16x64b";
case TMemAccessAtom::I16x128b:
return "16x128b";
case TMemAccessAtom::I16x256b:
return "16x256b";
case TMemAccessAtom::I16x32bx2:
return "16x32bx2";
}
llvm_unreachable("Unknown TMemAccessAtom");
}

LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom,
bool unpacked);

TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);

gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
RankedTensorType oldType,
unsigned numWarps);
gpu::DistributedEncodingTrait
SmallVector<gpu::DistributedEncodingTrait>
getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps,
ArrayRef<int64_t> ctaSplit = {1, 1});

std::optional<gpu::DistributedEncodingTrait>
getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
gpu::MemDescType memType, int numWarps);

SmallVector<gpu::DistributedEncodingTrait>
getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
gpu::MemDescType memType);

bool isDistributedLayoutTMemCompatible(Operation *op,
RankedTensorType tensorType,
gpu::MemDescType memType);
bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
gpu::MemDescType memType,
int numWarps);

gpu::DistributedEncodingTrait
getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps,
gpu::CTALayoutAttr ctaLayout);

std::optional<LinearLayout>
getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom,
unsigned numWarps,
gpu::CTALayoutAttr ctaLayout);

} // namespace mlir::triton::nvidia_gpu

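For context (an illustrative sketch, not code from this PR), the new TMemAccessAtom helpers compose as follows; the values in the comments follow directly from the switch statements above:

TMemAccessAtom atom = TMemAccessAtom::I16x256b;
int elems = getElementsPerThread(atom); // 4
const char *shape = getOpShape(atom);   // "16x256b"
// getTileLayout(ctx, atom, /*unpacked=*/true) then gives the linear layout
// of one such message tile.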
37 changes: 37 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h
@@ -0,0 +1,37 @@
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

#include <cstdint>
#include <functional>
#include <optional>

namespace mlir::triton::nvidia_gpu {

// Get the maximum number of registers per thread based on the context. This
// defaults to 256, but can be overridden by `ttg.maxnreg` set on the module or
// by a contextual register limit set by the compiler on partitions.
int getContextualMaxNReg(Operation *op);

struct TMemLdStEncodingInfo {
TMemAccessAtom atom;
LinearLayout reps;
ColumnAction perm;
int numRegsPerMessage;
std::optional<uint32_t> secondHalfOffset;
std::optional<ColumnAction> broadcast = std::nullopt;
bool unpacked = false;
unsigned vec = 1;
bool padding = false;
};

FailureOr<TMemLdStEncodingInfo>
computeTMemLdStEncodingInfo(RankedTensorType regTy, gpu::MemDescType memTy,
int maxnreg,
std::function<InFlightDiagnostic()> emitError = {});

} // namespace mlir::triton::nvidia_gpu

#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_
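
A hedged sketch of how these utilities might be combined at a call site (op, regTy, and memTy are assumed to come from an existing tmem load/store op; nothing below is prescribed by the PR):

int maxnreg = getContextualMaxNReg(op);
FailureOr<TMemLdStEncodingInfo> info =
    computeTMemLdStEncodingInfo(regTy, memTy, maxnreg);
if (succeeded(info)) {
  // info->atom picks the message shape (see getOpShape), and
  // info->numRegsPerMessage bounds the registers consumed per message.
}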
19 changes: 19 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -558,6 +558,25 @@ class LinearLayout {
return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}});
}

// Resizes the dimension to a new size that is smaller than or equal to the
// given size. These operations are similar to `sublayout`, but at the level
// of a single dimension.
[[nodiscard]] LinearLayout resizeInDim(StringAttr inDim,
int32_t newSize) const;
[[nodiscard]] LinearLayout resizeOutDim(StringAttr outDim,
int32_t newSize) const;

[[nodiscard]] LinearLayout renameInDim(StringAttr oldDim,
StringAttr newDim) const {
auto bases = getBases();
auto it = bases.find(oldDim);
assert(it != bases.end());
auto value = std::move(it->second);
bases.erase(it);
bases.insert({newDim, std::move(value)});
return LinearLayout(bases, getOutDims(),
/*requireSurjective=*/isSurjective());
}

// Concatenates two layouts by their in (resp. out) dimensions. The layouts
// must have the same output (resp. input) dimensions and sizes and different
// input (resp. output) dimensions. The input dimensions of this layout are
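An illustrative sketch of the new dimension helpers (identity1D is an existing LinearLayout factory; the dimension names and the surrounding ctx are arbitrary assumptions for the example):

auto S = [&](StringRef s) { return StringAttr::get(ctx, s); };
LinearLayout layout = LinearLayout::identity1D(32, S("register"), S("dim0"));
// Rename an input dimension; the bases are moved, not recomputed.
LinearLayout renamed = layout.renameInDim(S("register"), S("lane"));
// Shrink the renamed input dimension from 32 to 16.
LinearLayout resized = renamed.resizeInDim(S("lane"), 16);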
12 changes: 3 additions & 9 deletions lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp
@@ -21,16 +21,10 @@ namespace ttng = triton::nvidia_gpu;
RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
RankedTensorType type, MemDescType memdesc,
unsigned numWarps) {
Attribute encoding;
type = cast<RankedTensorType>(tc->convertType(type));
if (isa<ttng::TensorMemoryScalesEncodingAttr>(memdesc.getEncoding())) {
encoding = LinearEncodingAttr::get(
type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps));
} else {
auto tmemEnc = cast<ttng::TensorMemoryEncodingAttr>(memdesc.getEncoding());
encoding = ttng::getTmemCompatibleLayout(
tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps);
}
auto ctaLayout = getCTALayout(type.getEncoding());
auto encoding =
ttng::getDefaultLayoutForTmemLdSt(memdesc, numWarps, ctaLayout);
return type.cloneWithEncoding(encoding);
}

38 changes: 38 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -356,6 +356,44 @@ bool DotScaledOp::verifyOutputDims() {
return true;
}

LogicalResult DotScaledOp::verify() {
auto aShape = this->getA().getType().getShape();
int64_t rank = aShape.size();

auto k = aShape[rank - 1];
if (this->getAElemType() == ScaleDotElemType::E2M1) {
if (this->getLhsKPack())
k *= 2;
}
auto cShape = this->getC().getType().getShape();
int64_t mDim = cShape[cShape.size() - 2];
int64_t nDim = cShape[cShape.size() - 1];

if (getAScale()) {
auto aScaleShape = getAScale().getType().getShape();
if (aScaleShape[rank - 2] != mDim)
return this->emitError(
"scales M dimension must match the operand M dimension");
int scaleFactor =
isa<Float8E4M3FNType>(getAScale().getType().getElementType()) ? 16 : 32;
if (aScaleShape[rank - 1] != k / scaleFactor)
return this->emitError("scales K dimension must match the operand K "
"divided by the scale factor");
}
if (getBScale()) {
auto bScaleShape = getBScale().getType().getShape();
if (bScaleShape[rank - 2] != nDim)
return this->emitError(
"scales N dimension must match the operand N dimension");
int scaleFactor =
isa<Float8E4M3FNType>(getBScale().getType().getElementType()) ? 16 : 32;
if (bScaleShape[rank - 1] != k / scaleFactor)
return this->emitError("scales K dimension must match the operand K "
"divided by the scale factor");
}
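// Worked example (illustrative, not part of the verifier): for an E2M1 `a` of
// shape 128x256 with lhs_k_pack = true, k = 256 * 2 = 512. E8M0 scales use
// factor 32, so a_scale must have K dimension 512 / 32 = 16, while
// Float8E4M3FN scales use factor 16, giving 512 / 16 = 32.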
return success();
}

//-- MakeRangeOp --
OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
// make_range(start, start + 1) -> constant(start)