Commit 82d5122
Reapply "[LAYOUTS] Make CTALayout an honest-to-goodness LinearLayout (#8770)" (#5582)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 5c4ddd5 commit 82d5122

91 files changed: +1242 −1449 lines

bin/triton-tensor-layout.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ using namespace mlir;
 // clang-format off
 // Example usage:
 //
-// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
+// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
 //
 // triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
 //

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 1 deletion
@@ -489,7 +489,6 @@ using ::mlir::LLVM::delinearize;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
-using ::mlir::triton::gpu::CTALayoutAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;

include/triton/Dialect/TritonGPU/IR/Attributes.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_

 #include "mlir/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"

 #define GET_ATTRDEF_CLASSES

include/triton/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
@@ -15,11 +15,17 @@ set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
 mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
 mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
 mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS TritonGPUAttrImpls.td)
+mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(TritonGPUAttrDefsIncGen)

+set(LLVM_TARGET_DEFINITIONS CTAEncodingAttr.td)
+mlir_tablegen(CTAEncodingAttr.h.inc -gen-attrdef-decls)
+add_public_tablegen_target(TritonGPUCTAAttrIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
 mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+#ifndef TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
+#define TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
+
+#include "mlir/IR/Attributes.h"
+#include "triton/Tools/LinearLayout.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h.inc"
+#undef GET_ATTRDEF_CLASSES
+
+#endif // TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+//===----------------------------------------------------------------------===//
+// CTA encoding attribute definition emitted early to break interface cycles.
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITONGPU_CTAENCODING_ATTR_TD
+#define TRITONGPU_CTAENCODING_ATTR_TD
+
+include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td"
+
+//===----------------------------------------------------------------------===//
+// CTA Layout
+//===----------------------------------------------------------------------===//
+
+def CTAEncodingAttr : TritonGPU_Attr<"CTAEncoding", "cta_encoding"> {
+  let parameters = (ins LinearLayoutParam:$linearLayout);
+
+  let description = [{
+    Describes how blocks (CTAs) in a cooperative grid array (CGA) map onto logical
+    tensor dimensions. The `LinearLayout` maps from `block` into `dim0`, `dim1`...
+  }];
+
+  let extraClassDeclaration = [{
+    static CTAEncodingAttr getDefault(MLIRContext *context, int rank);
+    // Legacy, we should kill this! Note that it is not true in general that
+    // fromSplitParams(enc.getCTAsPerCGA(), enc.getCTASplitNum(), enc.getCTAOrder()) == enc!!
+    static CTAEncodingAttr fromSplitParams(MLIRContext *context,
+                                           ArrayRef<unsigned> CTAsPerCGA,
+                                           ArrayRef<unsigned> CTASplitNum,
+                                           ArrayRef<unsigned> CTAOrder);

+    unsigned getRank() const { return getLinearLayout().getNumOutDims(); }
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+    SmallVector<unsigned> getCTAOrder() const;
+  }];
+
+  let genVerifyDecl = 1;
+}
+
+#endif // TRITONGPU_CTAENCODING_ATTR_TD
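
The legacy fromSplitParams triple is easiest to read as a plain delinearization of the block id. The sketch below is an illustrative model only (blockToCoord is a hypothetical helper, not Triton's LinearLayout machinery): CTAOrder walks dimensions from fastest- to slowest-varying, CTAsPerCGA bounds each block coordinate, and wrapping at CTASplitNum makes several blocks land on the same tensor slice, i.e. replication.

// Illustrative model of the legacy split-params semantics; not Triton code.
#include <cstdio>
#include <vector>

// Delinearize a block id into per-dimension CTA coordinates.
std::vector<unsigned> blockToCoord(unsigned blockId,
                                   const std::vector<unsigned> &CTAsPerCGA,
                                   const std::vector<unsigned> &CTASplitNum,
                                   const std::vector<unsigned> &CTAOrder) {
  std::vector<unsigned> coord(CTAsPerCGA.size(), 0);
  for (unsigned dim : CTAOrder) { // fastest-varying dimension first
    // Wrapping at CTASplitNum replicates the tensor across blocks.
    coord[dim] = (blockId % CTAsPerCGA[dim]) % CTASplitNum[dim];
    blockId /= CTAsPerCGA[dim];
  }
  return coord;
}

int main() {
  // 8 blocks: 4 along dim1 (fully split), 2 along dim0 (replicated).
  std::vector<unsigned> ctasPerCGA = {2, 4}, splitNum = {1, 4}, order = {1, 0};
  for (unsigned b = 0; b < 8; ++b) {
    auto c = blockToCoord(b, ctasPerCGA, splitNum, order);
    std::printf("block %u -> (%u, %u)\n", b, c[0], c[1]);
  }
}

This also shows the round-trip caveat in the comment above: a general LinearLayout over block bits can permute or swizzle blocks in ways no (CTAsPerCGA, CTASplitNum, CTAOrder) triple expresses, which is why the new attribute stores the layout itself.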

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
                         type.getShape());
 }

-CTALayoutAttr getCTALayout(Attribute layout);
+CTAEncodingAttr getCTALayout(Attribute layout);

 SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
include/triton/Dialect/TritonGPU/IR/LayoutUtility.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 15 deletions
@@ -17,7 +17,7 @@ class SwizzledSharedEncodingAttr;
 class NVMMASharedEncodingAttr;
 class TensorOrMemDesc;
 class MemDescType;
-class CTALayoutAttr;
+class CTAEncodingAttr;

 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -77,9 +77,9 @@ LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
 // given shape.
 //
 // See the nomenclature note at the top of LinearLayoutConversions.cpp for why
-// the variable with type CTALayoutAttr is called cgaLayoutAttr.
+// the variable with type CTAEncodingAttr is called cgaLayoutAttr.
 LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
-                                    CTALayoutAttr cgaLayoutAttr,
+                                    CTAEncodingAttr cgaLayoutAttr,
                                     ArrayRef<int64_t> shape);

 // In this function, we construct a linear layout representing the
@@ -133,7 +133,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
 LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
                                           ArrayRef<int64_t> shape, int opIdx,
                                           ArrayRef<unsigned> warpsPerCTA,
-                                          CTALayoutAttr ctaLayout);
+                                          CTAEncodingAttr ctaLayout);

 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
@@ -149,16 +149,5 @@ std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 // Create the core layout (atom in the PTX manual) a given nvmma shared encoding
 LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle);
-
-// Make a LinearLayout that maps a block-id to an N-dimensional index.
-//
-// The tensor is split up into CTAsPerCGA pieces, which are distributed among
-// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups).
-//
-// See the nomenclature note at the top of the LinearLayoutConversions.cpp file
-// for an explanation of why this is called makeCgaLayout when it accepts a
-// CTALayoutAttr.
-LinearLayout makeCgaLayout(CTALayoutAttr layout);
-
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
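
For intuition about what combineCtaCgaWithShape produces, here is a hedged sketch under simplifying assumptions (each block owns a contiguous tile and the tile shape divides the tensor shape; Triton's real implementation composes LinearLayouts and also handles broadcasting): the coordinate within a block's tile plus the block tile's origin gives the global tensor coordinate.

// Conceptual model of combining a within-block layout with a CGA layout;
// not Triton's implementation, which works on LinearLayouts directly.
#include <cstdio>
#include <vector>

// Global coordinate = block coordinate * tile shape + coordinate in tile.
std::vector<unsigned> combine(const std::vector<unsigned> &blockCoord,
                              const std::vector<unsigned> &inTileCoord,
                              const std::vector<unsigned> &tileShape) {
  std::vector<unsigned> global(tileShape.size());
  for (size_t d = 0; d < tileShape.size(); ++d)
    global[d] = blockCoord[d] * tileShape[d] + inTileCoord[d];
  return global;
}

int main() {
  // A 128x256 tensor split into 1x4 block tiles of 128x64 each.
  auto g = combine(/*blockCoord=*/{0, 2}, /*inTileCoord=*/{5, 7},
                   /*tileShape=*/{128, 64});
  std::printf("global = (%u, %u)\n", g[0], g[1]); // prints (5, 135)
}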
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+//===----------------------------------------------------------------------===//
+// Base definitions shared by TritonGPU attribute TableGen files.
+// Splitting these out lets us emit certain attributes (e.g. CTAEncodingAttr)
+// before interface headers without creating circular dependencies.
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITONGPU_ATTRBASE_TD
+#define TRITONGPU_ATTRBASE_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "triton/Dialect/Triton/IR/TritonInterfaces.td"
+include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
+
+// Traits used across several attrs.
+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+
+// Common parameter helpers.
+def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
+                                            "linear layout"> {
+  let cppAccessorType = "const LinearLayout &";
+}
+
+// Base class for all TritonGPU attributes.
+class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<TritonGPU_Dialect, name, traits> {
+
+  let description = [{
+    TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
+    how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
+    $\mathcal{L}$ that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers $T$ corresponding
+    to the indices of the CUDA threads allowed to access some data at index $i$.
+
+    For example, let us consider the layout function:
+    $\mathcal{L}(0, 0) = \{0, 4\}$
+    $\mathcal{L}(0, 1) = \{1, 5\}$
+    $\mathcal{L}(1, 0) = \{2, 6\}$
+    $\mathcal{L}(1, 1) = \{3, 7\}$
+
+    Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
+    - T[0,0] is owned by both cuda thread 0 and 4
+    - T[0,1] is owned by both cuda thread 1 and 5
+    - T[1,0] is owned by both cuda thread 2 and 6
+    - T[1,1] is owned by both cuda thread 3 and 7
+
+    Right now, Triton implements two main classes of layouts: shared, and distributed.
+  }];
+  let attrName = "triton.gpu." # attrMnemonic;
+
+  code extraBaseClassDeclaration = [{
+  }];
+}
+
+#endif // TRITONGPU_ATTRBASE_TD
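
As a sanity check on the example in that description, one closed form consistent with the four ownership sets is thread = 2·i + j, with a replicated copy at offset 4 (this formula is our inference; the source only gives the table). A tiny sketch that reproduces it:

// Reproduces the example ownership sets from the description above.
#include <cstdio>

int main() {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      // Each index (i, j) is owned by a base thread and its +4 replica.
      std::printf("L(%d, %d) = {%d, %d}\n", i, j, 2 * i + j, 2 * i + j + 4);
}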
