@@ -28,118 +28,18 @@ inline SmallVector<Value> translateTMAIndices(BuilderT &builder, Location loc,
   return indices;
 }
 
-inline gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
-                                                  ArrayRef<int64_t> shape) {
-  auto rank = shape.size();
-  if (ctaLayout.getRank() == rank)
-    return ctaLayout;
+gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
+                                           ArrayRef<int64_t> shape);
 
-  auto ctx = ctaLayout.getContext();
-  if (ctaLayout.getRank() > rank) {
-    unsigned rankDiff = ctaLayout.getRank() - rank;
-    return gpu::CTALayoutAttr::get(
-        ctx, ctaLayout.getCTAsPerCGA().drop_front(rankDiff),
-        ctaLayout.getCTASplitNum().drop_front(rankDiff),
-        ctaLayout.getCTAOrder().drop_front(rankDiff));
-  }
-  // For rank-reducing loads, we need to rank-increase the CTA Layout
-  auto rankDiff = rank - ctaLayout.getRank();
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    assert(shape[i] == 1 && "Should only happen for rank-reducing loads");
-  }
-  SmallVector<unsigned> CTAsPerCGA(rank, 1);
-  SmallVector<unsigned> CTASplitNum(rank, 1);
-  SmallVector<unsigned> CTAOrder(rank, 1);
-
-  llvm::copy(ctaLayout.getCTAsPerCGA(), CTAsPerCGA.begin() + rankDiff);
-  llvm::copy(ctaLayout.getCTASplitNum(), CTASplitNum.begin() + rankDiff);
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    CTAOrder[i] = rank - i;
-  }
-  llvm::copy(ctaLayout.getCTAOrder(), CTAOrder.begin() + rankDiff);
-  return gpu::CTALayoutAttr::get(ctx, CTAsPerCGA, CTASplitNum, CTAOrder);
-}
-
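Note on the removed helper above: updateCTALayoutForShape reconciles a CTA layout with a shape of a different rank by dropping leading layout dimensions when the layout is wider, or by padding leading unit dimensions when a rank-reducing load leaves extra size-1 dims in front. Below is a minimal standalone sketch of that rule on plain vectors; the helper name adjustRank is invented here, not the Triton API, and it only illustrates the CTAsPerCGA/CTASplitNum-style padding, while the CTAOrder permutation gets its own remapping in the removed code.

// Sketch only: mirrors the drop-leading / pad-leading rank adjustment.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<unsigned> adjustRank(std::vector<unsigned> dims, size_t newRank) {
  if (dims.size() > newRank) {
    // Layout has more dims than the shape: drop the leading entries.
    dims.erase(dims.begin(), dims.begin() + (dims.size() - newRank));
  } else {
    // Rank-reducing load: pad leading entries with 1 so the ranks match.
    dims.insert(dims.begin(), newRank - dims.size(), 1u);
  }
  return dims;
}

int main() {
  assert((adjustRank({2, 1}, 3) == std::vector<unsigned>{1, 2, 1}));
  assert((adjustRank({2, 4, 1}, 2) == std::vector<unsigned>{4, 1}));
  return 0;
}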
-inline gpu::SharedEncodingTrait
+gpu::SharedEncodingTrait
 updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
-                       RankedTensorType tensorType) {
-  auto ctx = encoding.getContext();
-  auto ctaLayout = gpu::getCTALayout(encoding);
-  if (auto nvmmaEnc = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding)) {
-    auto existingCta = nvmmaEnc.getCTALayout();
-    if (!existingCta)
-      return nvmmaEnc;
-
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::NVMMASharedEncodingAttr::get(
-        ctx, nvmmaEnc.getSwizzlingByteWidth(), nvmmaEnc.getTransposed(),
-        nvmmaEnc.getElementBitWidth(), nvmmaEnc.getFp4Padded(), newCtaEnc);
-  }
-  if (auto swizEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding)) {
-    auto existingCta = swizEnc.getCTALayout();
-    if (!existingCta)
-      return swizEnc;
-
-    auto rank = tensorType.getRank();
-    auto oldOrder = swizEnc.getOrder();
-    SmallVector<unsigned> order;
-    for (int i = 0; i + oldOrder.size() < rank; ++i)
-      order.push_back(rank - i - 1);
-    for (int i = 0; i < oldOrder.size(); ++i) {
-      // If it is a rank-reducing load, we need to drop the last dimensions.
-      if (oldOrder[i] >= rank)
-        continue;
-      order.push_back(oldOrder[i]);
-    }
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::SwizzledSharedEncodingAttr::get(
-        ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
-        order, newCtaEnc);
-  }
-
-  constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding";
-  if (op)
-    op->emitError() << msg;
-  llvm::report_fatal_error(msg);
-}
+                       RankedTensorType tensorType);
 
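For the SwizzledShared branch removed above, the dimension order is rebuilt when the tensor rank changes: new leading dimensions are appended first as the slowest-varying entries, and order entries that no longer fit the reduced rank are dropped. A standalone sketch of just that rebuild, using a hypothetical rebuildOrder helper rather than the Triton API:

// Sketch only: order rebuild for rank-increased or rank-reduced tensors.
#include <cassert>
#include <vector>

std::vector<unsigned> rebuildOrder(const std::vector<unsigned> &oldOrder,
                                   unsigned rank) {
  std::vector<unsigned> order;
  for (unsigned i = 0; i + oldOrder.size() < rank; ++i)
    order.push_back(rank - i - 1); // new outermost dims go first
  for (unsigned d : oldOrder)
    if (d < rank) // drop dims that no longer exist after rank reduction
      order.push_back(d);
  return order;
}

int main() {
  // Rank-increase: old order [1, 0] on a rank-3 tensor -> [2, 1, 0].
  assert((rebuildOrder({1, 0}, 3) == std::vector<unsigned>{2, 1, 0}));
  // Rank-reduce: old order [2, 1, 0] on a rank-2 tensor -> [1, 0].
  assert((rebuildOrder({2, 1, 0}, 2) == std::vector<unsigned>{1, 0}));
  return 0;
}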
-inline triton::gpu::SharedEncodingTrait
+triton::gpu::SharedEncodingTrait
 getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
-                          Value desc) {
-  auto descBlockType = cast<TensorDescType>(desc.getType()).getBlockType();
-  Attribute encoding = descBlockType.getEncoding();
-  if (!encoding) {
-    constexpr auto msg =
-        "Internal Error: Tensor descriptor should have encoding set";
-    if (op)
-      op->emitError() << msg;
-    llvm::report_fatal_error(msg);
-  }
-  auto sharedEnc = cast<gpu::SharedEncodingTrait>(encoding);
-  if (descBlockType.getShape() == tensorType.getShape())
-    return sharedEnc;
-
-  return updateEncodingForShape(op, sharedEnc, tensorType);
-}
+                          Value desc);
 
-inline int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape) {
-  assert(encoding);
-  auto mmaEncoding =
-      llvm::dyn_cast_or_null<gpu::NVMMASharedEncodingAttr>(encoding);
-
-  // The bounding box inner dimension must be less than or equal to the
-  // swizzle size.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  // We clamp the block size and the codegen will emit multiple copy
-  // operations.
-  if (mmaEncoding) {
-    auto elemSize = mmaEncoding.getElementBitWidth() / 8;
-    return mmaEncoding.getSwizzlingByteWidth() / elemSize;
-  }
-
-  auto shapePerCTA = gpu::getShapePerCTA(encoding, shape);
-  return shapePerCTA.back();
-}
+int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape);
 
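The removed getTMAContigDim body caps the TMA bounding-box inner dimension at the swizzle width, i.e. swizzlingByteWidth divided by the element size in bytes, and codegen then issues multiple copies to cover the full block. A quick arithmetic sketch of that rule (helper name assumed, not the Triton API):

// Sketch only: innermost extent capped at swizzleBytes / bytesPerElement.
#include <cassert>
#include <cstdint>

int64_t contigDimForSwizzle(int64_t swizzleBytes, int64_t elementBitWidth) {
  return swizzleBytes / (elementBitWidth / 8);
}

int main() {
  assert(contigDimForSwizzle(128, 16) == 64); // 128B swizzle, fp16/bf16
  assert(contigDimForSwizzle(64, 32) == 16);  // 64B swizzle, fp32
  assert(contigDimForSwizzle(128, 8) == 128); // 128B swizzle, fp8/int8
  return 0;
}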
 inline int64_t getTMAContigDim(RankedTensorType tensorType) {
   return getTMAContigDim(tensorType.getEncoding(), tensorType.getShape());
@@ -149,61 +49,9 @@ inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
   return getTMAContigDim(memDescType.getEncoding(), memDescType.getShape());
 }
 
-inline std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
-  if (!mmaEncoding) {
-    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding);
-    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
-        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
-      if (op)
-        op->emitError("Unhandled encoding type");
-      return std::nullopt;
-    }
-  }
-
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-  assert(!fp4Padded || swizzleBytes == 128 &&
-                           "elem type .b4x16_p64 supports only 128B swizzling");
+std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);
 
-  int32_t swizzleMode = 0;
-  if (swizzleBytes == 128) {
-    swizzleMode = 3;
-  } else if (swizzleBytes == 64) {
-    swizzleMode = 2;
-  } else if (swizzleBytes == 32) {
-    swizzleMode = 1;
-  }
-  return swizzleMode;
-}
-
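The swizzle-mode codes removed above map the shared-memory swizzle byte width to 0 (none), 1 (32B), 2 (64B), or 3 (128B); these values appear to match the CUDA driver's CUtensorMapSwizzle enum, but that correspondence is an inference, not stated in this diff. A standalone sketch with a hypothetical helper name:

// Sketch only: byte width -> swizzle mode code.
#include <cassert>
#include <cstdint>

int32_t swizzleModeForBytes(unsigned swizzleBytes) {
  switch (swizzleBytes) {
  case 128:
    return 3; // 128B swizzle
  case 64:
    return 2; // 64B swizzle
  case 32:
    return 1; // 32B swizzle
  default:
    return 0; // no swizzling
  }
}

int main() {
  assert(swizzleModeForBytes(128) == 3);
  assert(swizzleModeForBytes(0) == 0);
  return 0;
}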
-inline std::optional<int> getTMAElementType(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-
-  if (fp4Padded)
-    return 14; // .b4x16_p64
-
-  auto elemSize = ty.getBlockType().getElementTypeBitWidth() / 8;
-  switch (elemSize) {
-  case 1:
-    return 0;
-  case 2:
-    return 1;
-  case 4:
-    return 2;
-  default:
-    break;
-  }
-  if (op) {
-    op->emitError()
-        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
-        << elemSize;
-  }
-  return std::nullopt;
-}
+std::optional<int> getTMAElementType(Operation *op, TensorDescType ty);
 
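Similarly, the removed getTMAElementType selects a TMA data-type code from the element byte size (1 -> 0, 2 -> 1, 4 -> 2) and uses 14 for padded fp4 (.b4x16_p64, per the removed comment). The 0/1/2 codes look like the CUDA driver's CUtensorMapDataType values for UINT8/UINT16/UINT32; treat that as an assumption. Sketch with a hypothetical helper name:

// Sketch only: element byte size (plus fp4 padding flag) -> TMA type code.
#include <cassert>
#include <optional>

std::optional<int> tmaElementTypeForSize(unsigned elemBytes, bool fp4Padded) {
  if (fp4Padded)
    return 14; // .b4x16_p64, per the removed comment
  switch (elemBytes) {
  case 1:
    return 0;
  case 2:
    return 1;
  case 4:
    return 2;
  default:
    return std::nullopt; // unsupported element size
  }
}

int main() {
  assert(tmaElementTypeForSize(2, false) == 1); // fp16/bf16
  assert(tmaElementTypeForSize(1, true) == 14); // padded fp4
  assert(!tmaElementTypeForSize(8, false));     // not representable
  return 0;
}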
 template <typename BuilderT>
 mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,