65 changes: 65 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
int numWarps, int threadsPerWarp, int numCTAs);

// For each output dimension d, ensure that the layout's output size (i.e., its
// codomain) does not exceed shape[d]. Do this without changing the size of the
// layout's inputs (i.e., leave its domain unchanged).
//
// This function is invariant to the order of the layout's input and output
// dimensions.
//
// We achieve this by setting the largest value in each output dimension d to 0
// because bases that map to a location larger than shape[d]
// effectively duplicate along that dimension. For example, consider a layout
// with an output dimension of size 32, on which we call ensureLayoutNotLargerThan
// to shrink the output dimension to size 8:
//
// L(register=1) = 8
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 16
//
// In the first step, we shrink the output dimension size to 16 by setting
// L(lane=2) to 0:
//
// L(register=1) = 8
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 0
//
// This means that lane=2 has the same data as lane=0.
//
// Now the output dimension of this layout has a size of 16, which is still
// larger than 8. We find the current largest value in the output dimension,
// which is L(register=1) = 8, and we set L(register=1) to 0:
//
// L(register=1) = 0
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 0
//
// Now the output dimension of this layout has a size of 8, which is the desired
// size. Note that this method works only because the bases are powers of two,
// which is the case for DistributedLayouts.
//
// If broadcastRegisters is false, we instead remove (rather than zero out) any
// register basis that is larger than the desired shape. In the example above we
// would get:
// L(register=1) = 4
// L(register=2) = 1
// L(lane=1) = 2
// L(lane=2) = 0
LinearLayout
ensureLayoutNotLargerThan(const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
bool broadcastRegisters = true);
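
To make the shrinking procedure described in the comment above concrete, here is a minimal standalone C++ sketch (not part of this diff) that replays the worked example with plain standard-library containers instead of LinearLayout. The names Bases and shrinkOutDimTo are invented for illustration, and the sketch assumes the out-dim size is exactly twice the largest basis, which holds when the bases are powers of two spanning the dimension.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Bases for a single output dimension: bases[inDim][i] holds L(inDim = 2^i).
using Bases = std::map<std::string, std::vector<int32_t>>;

// Zero out the largest basis until the output dimension fits in desiredSize.
static void shrinkOutDimTo(Bases &bases, int32_t desiredSize) {
  auto currentSize = [&bases] {
    int32_t size = 1;
    for (const auto &[inDim, vals] : bases)
      for (int32_t v : vals)
        size = std::max<int32_t>(size, 2 * v); // assumed out-dim size
    return size;
  };
  while (currentSize() > desiredSize) {
    // Find the largest remaining basis and set it to 0, which duplicates data
    // along that input rather than changing the domain.
    std::string maxDim;
    size_t maxIdx = 0;
    int32_t maxVal = -1;
    for (const auto &[inDim, vals] : bases)
      for (size_t i = 0; i < vals.size(); ++i)
        if (vals[i] > maxVal) {
          maxVal = vals[i];
          maxDim = inDim;
          maxIdx = i;
        }
    bases[maxDim][maxIdx] = 0;
  }
}

int main() {
  // The example above: out-dim size 32, shrink to 8.
  Bases bases = {{"register", {8, 4, 1}}, {"lane", {2, 16}}};
  shrinkOutDimTo(bases, 8);
  // Prints lane: 2 0 and register: 0 4 1 (std::map iterates keys
  // alphabetically), matching the two steps walked through in the comment.
  for (const auto &[inDim, vals] : bases) {
    std::cout << inDim << ":";
    for (int32_t v : vals)
      std::cout << " " << v;
    std::cout << "\n";
  }
  return 0;
}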

// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
// along its most-minor dimension ("register" for register layouts, "offset" for
// shared layouts).
//
// This function is invariant to the order of the layout's input dimensions, but
// it cares about the order of the output dims, which should be minor-to-major.
LinearLayout ensureLayoutNotSmallerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
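
A correspondingly small sketch of the growing direction described above (again not part of this diff; growOutDimTo is an illustrative name): power-of-two bases are appended to the most-minor input dimension until the output dimension covers the requested size.

#include <cstdint>
#include <vector>

// Append bases to the most-minor input dimension ("register" or "offset")
// until the out-dim covers desiredSize. Assumes the current out-dim size is a
// power of two (>= 1) equal to twice the largest existing basis.
static void growOutDimTo(std::vector<int32_t> &minorInBases, int32_t currentSize,
                         int32_t desiredSize) {
  for (int32_t b = currentSize; b < desiredSize; b *= 2)
    minorInBases.push_back(b); // each appended basis doubles the covered range
}

// Usage: a "register" input whose out-dim covers [0, 4), grown to cover [0, 16):
//   std::vector<int32_t> regBases = {1, 2};
//   growOutDimTo(regBases, /*currentSize=*/4, /*desiredSize=*/16);
//   // regBases is now {1, 2, 4, 8}.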

// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);
30 changes: 28 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
code extraBaseClassDeclaration = [{
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}

@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];
}

//===----------------------------------------------------------------------===//
// Linear Layout Encoding
//===----------------------------------------------------------------------===//

def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
let mnemonic = "linear";

let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
}];

let parameters = (ins "LinearLayout":$linearLayout);

let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() const;
SmallVector<unsigned> getOrder() const;
}];

let genVerifyDecl = 1;
// Example of assembly format:
// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
// warp = [[16, 0], [32, 0]],
// block = []}>
let hasCustomAssemblyFormat = 1;
}
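
As a rough usage sketch, not taken from this PR: the TableGen definition above should produce the usual generated builder and parameter accessor, alongside the helpers declared in extraClassDeclaration. The accessor names below are assumptions based on standard MLIR AttrDef codegen, and the headers and namespaces are those used elsewhere in the Triton codebase.

#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical helper, not from this PR: wraps an existing LinearLayout in the
// new encoding attribute and reads back the declared accessors.
static mlir::triton::gpu::LinearEncodingAttr
wrapInLinearEncoding(mlir::MLIRContext *ctx, const mlir::triton::LinearLayout &ll) {
  // get() and getLinearLayout() are assumed to be generated from the
  // `parameters` list above; getOrder()/getContigPerThread() are declared in
  // extraClassDeclaration.
  auto attr = mlir::triton::gpu::LinearEncodingAttr::get(ctx, ll);
  mlir::triton::LinearLayout stored = attr.getLinearLayout();
  llvm::SmallVector<unsigned> order = attr.getOrder();
  llvm::SmallVector<unsigned> contig = attr.getContigPerThread();
  (void)stored;
  (void)order;
  (void)contig;
  return attr;
}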


//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
3 changes: 3 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -9,6 +9,7 @@
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
// (e.g. by reshaping) then the order doesn't really affect anything.
auto getInDimNames() const { return llvm::make_first_range(bases); }
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }

// Gets the position that this outDim occupies in getOutDimNames(). Asserts
// if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
return !(lhs == rhs);
}
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
friend size_t hash_value(const LinearLayout &layout);

private:
// Factory function that gracefully fails rather than asserts if the layout is
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (isa<BlockedEncodingAttr>(layout)) {
return true;
}
if (isa<LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
return layoutIsOK(slice.getParent());
}
5 changes: 2 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
@@ -206,7 +206,6 @@ auto dstTy = op.getResult().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
