2 changes: 1 addition & 1 deletion bin/triton-tensor-layout.cpp
@@ -22,7 +22,7 @@ using namespace mlir;
// clang-format off
// Example usage:
//
// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
1 change: 0 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -489,7 +489,6 @@ using ::mlir::LLVM::delinearize;
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonGPU/IR/Attributes.h
@@ -2,6 +2,7 @@
#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"

#define GET_ATTRDEF_CLASSES
8 changes: 7 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
@@ -15,11 +15,17 @@ set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)

set(LLVM_TARGET_DEFINITIONS TritonGPUAttrImpls.td)
mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

set(LLVM_TARGET_DEFINITIONS CTAEncodingAttr.td)
mlir_tablegen(CTAEncodingAttr.h.inc -gen-attrdef-decls)
add_public_tablegen_target(TritonGPUCTAAttrIncGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h
@@ -0,0 +1,11 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
#define TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_

#include "mlir/IR/Attributes.h"
#include "triton/Tools/LinearLayout.h"

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h.inc"
#undef GET_ATTRDEF_CLASSES

#endif // TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
40 changes: 40 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td
@@ -0,0 +1,40 @@
//===----------------------------------------------------------------------===//
// CTA encoding attribute definition emitted early to break interface cycles.
//===----------------------------------------------------------------------===//

#ifndef TRITONGPU_CTAENCODING_ATTR_TD
#define TRITONGPU_CTAENCODING_ATTR_TD

include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td"

//===----------------------------------------------------------------------===//
// CTA Layout
//===----------------------------------------------------------------------===//

def CTAEncodingAttr : TritonGPU_Attr<"CTAEncoding", "cta_encoding"> {
let parameters = (ins LinearLayoutParam:$linearLayout);

let description = [{
Describes how the blocks (CTAs) of a cooperative grid array (CGA) map onto the logical
tensor dimensions. The `LinearLayout` maps the `block` input dimension to the output
dimensions `dim0`, `dim1`, ...
}];

let extraClassDeclaration = [{
static CTAEncodingAttr getDefault(MLIRContext *context, int rank);
// Legacy; we should remove this. Note that it is not true in general that
// fromSplitParams(enc.getCTAsPerCGA(), enc.getCTASplitNum(), enc.getCTAOrder()) == enc.
static CTAEncodingAttr fromSplitParams(MLIRContext *context,
                                       ArrayRef<unsigned> CTAsPerCGA,
                                       ArrayRef<unsigned> CTASplitNum,
                                       ArrayRef<unsigned> CTAOrder);

unsigned getRank() const { return getLinearLayout().getNumOutDims(); }
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTASplitNum() const;
SmallVector<unsigned> getCTAOrder() const;
}];

let genVerifyDecl = 1;
}

#endif // TRITONGPU_CTAENCODING_ATTR_TD
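For intuition, here is a minimal C++ sketch of what `fromSplitParams` might do, modeled on the legacy `makeCgaLayout` helper whose declaration this PR removes from LinearLayoutConversions.h. It assumes the `LinearLayout::identity1D`/`zeros1D` combinators and the `standardOutDimNames` helper behave as in triton/Tools/LinearLayout.h; the PR's actual implementation is authoritative, this is only a sketch.

```cpp
// Sketch only: build the block -> (dim0, dim1, ...) LinearLayout from the
// legacy split parameters. Each dimension contributes log2(CTASplitNum[dim])
// "real" block bits plus log2(CTAsPerCGA[dim] / CTASplitNum[dim]) broadcast
// bits (distinct blocks that map to the same tensor slice).
CTAEncodingAttr CTAEncodingAttr::fromSplitParams(MLIRContext *context,
                                                 ArrayRef<unsigned> CTAsPerCGA,
                                                 ArrayRef<unsigned> CTASplitNum,
                                                 ArrayRef<unsigned> CTAOrder) {
  StringAttr kBlock = StringAttr::get(context, "block");
  int rank = CTAOrder.size();
  SmallVector<StringAttr> outDims = standardOutDimNames(context, rank);
  LinearLayout layout = LinearLayout::empty();
  // Walk the dimensions minor-to-major, as the legacy makeCgaLayout did.
  for (int i = 0; i < rank; ++i) {
    int dim = CTAOrder[i];
    layout *=
        LinearLayout::identity1D(CTASplitNum[dim], kBlock, outDims[dim]) *
        LinearLayout::zeros1D(CTAsPerCGA[dim] / CTASplitNum[dim], kBlock,
                              outDims[dim]);
  }
  return get(context, layout.transposeOuts(outDims));
}
```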
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -210,7 +210,7 @@ inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
type.getShape());
}

CTALayoutAttr getCTALayout(Attribute layout);
CTAEncodingAttr getCTALayout(Attribute layout);

SmallVector<unsigned> getCTAsPerCGA(Attribute layout);

8 changes: 0 additions & 8 deletions include/triton/Dialect/TritonGPU/IR/LayoutUtility.h

This file was deleted.

19 changes: 4 additions & 15 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -17,7 +17,7 @@ class SwizzledSharedEncodingAttr;
class NVMMASharedEncodingAttr;
class TensorOrMemDesc;
class MemDescType;
class CTALayoutAttr;
class CTAEncodingAttr;

// - BlockedEncodingAttrs have the following input dimensions.
//
@@ -77,9 +77,9 @@ LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
// given shape.
//
// See the nomenclature note at the top of LinearLayoutConversions.cpp for why
// the variable with type CTALayoutAttr is called cgaLayoutAttr.
// the variable with type CTAEncodingAttr is called cgaLayoutAttr.
LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
CTALayoutAttr cgaLayoutAttr,
CTAEncodingAttr cgaLayoutAttr,
ArrayRef<int64_t> shape);

// In this function, we construct a linear layout representing the
Expand Down Expand Up @@ -133,7 +133,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
ArrayRef<int64_t> shape, int opIdx,
ArrayRef<unsigned> warpsPerCTA,
CTALayoutAttr ctaLayout);
CTAEncodingAttr ctaLayout);

// Create LinearLayout for nvidia mma tile.
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
@@ -149,16 +149,5 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
// Create the core layout (atom in the PTX manual) for a given NVMMA shared encoding
LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
bool disableSwizzle);

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. group).
//
// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
// for an explanation of why this is called makeCgaLayout when it accepts a
// CTALayoutAttr.
LinearLayout makeCgaLayout(CTALayoutAttr layout);

} // namespace mlir::triton::gpu
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
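As a usage sketch (not code from this PR), a typical `toLinearLayout`-style conversion builds the within-CTA layout first and then folds in the CTA encoding via `combineCtaCgaWithShape`. `buildWithinCtaLayout` below is a hypothetical stand-in for the per-encoding construction; `getCTALayout` and `combineCtaCgaWithShape` are the declarations shown in this diff.

```cpp
// Sketch only: combine a per-CTA layout with the tensor's CTA encoding.
// buildWithinCtaLayout is hypothetical; the other calls match the
// declarations in Dialect.h and LinearLayoutConversions.h above.
LinearLayout toLinearLayoutSketch(RankedTensorType type, Attribute encoding) {
  LinearLayout withinCta = buildWithinCtaLayout(type, encoding); // hypothetical
  CTAEncodingAttr cgaLayout = getCTALayout(encoding);
  // Distribute/broadcast the per-CTA layout across the blocks of the CGA so
  // the combined layout covers the full tensor shape.
  return combineCtaCgaWithShape(withinCta, cgaLayout, type.getShape());
}
```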
54 changes: 54 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td
@@ -0,0 +1,54 @@
//===----------------------------------------------------------------------===//
// Base definitions shared by TritonGPU attribute TableGen files.
// Splitting these out lets us emit certain attributes (e.g. CTAEncodingAttr)
// before interface headers without creating circular dependencies.
//===----------------------------------------------------------------------===//

#ifndef TRITONGPU_ATTRBASE_TD
#define TRITONGPU_ATTRBASE_TD

include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"

// Traits used across several attrs.
def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;

// Common parameter helpers.
def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
"linear layout"> {
let cppAccessorType = "const LinearLayout &";
}

// Base class for all TritonGPU attributes.
class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits> {

let description = [{
TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
$\mathcal{L}$ that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers $T$ corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.

For example, let us consider the layout function:
  $\mathcal{L}(0, 0) = \{0, 4\}$
  $\mathcal{L}(0, 1) = \{1, 5\}$
  $\mathcal{L}(1, 0) = \{2, 6\}$
  $\mathcal{L}(1, 1) = \{3, 7\}$

Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
- $T[0,0]$ is owned by both CUDA threads 0 and 4
- $T[0,1]$ is owned by both CUDA threads 1 and 5
- $T[1,0]$ is owned by both CUDA threads 2 and 6
- $T[1,1]$ is owned by both CUDA threads 3 and 7

Right now, Triton implements two main classes of layouts: shared and distributed.
}];
let attrName = "triton.gpu." # attrMnemonic;

code extraBaseClassDeclaration = [{
}];
}

#endif // TRITONGPU_ATTRBASE_TD
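To tie the worked $\mathcal{L}$ example above to `LinearLayout`, here is a hedged C++ sketch (again assuming the `identity1D`/`zeros1D` combinators from triton/Tools/LinearLayout.h) in which lane bits 0 and 1 select the element and lane bit 2 is a broadcast bit, so threads $t$ and $t+4$ own the same element:

```cpp
// Sketch only: the example layout L(0,0)={0,4}, L(0,1)={1,5}, L(1,0)={2,6},
// L(1,1)={3,7} expressed as a LinearLayout over 8 lanes of a 2x2 tensor.
LinearLayout exampleLayout(MLIRContext *ctx) {
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kDim0 = StringAttr::get(ctx, "dim0");
  StringAttr kDim1 = StringAttr::get(ctx, "dim1");
  return LinearLayout::identity1D(2, kLane, kDim1) * // lane bit 0 -> column
         LinearLayout::identity1D(2, kLane, kDim0) * // lane bit 1 -> row
         LinearLayout::zeros1D(2, kLane, kDim0);     // lane bit 2 -> broadcast
}
```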