@@ -28,7 +28,6 @@ struct MMASMEMDescriptor {
   SMEMDescriptor descriptor;
   int32_t swizzlingByteWidth;
   int32_t bitwidth;
-  bool twoCTAs;
   bool transposed;
   bool fp4Padded;
 };
@@ -53,77 +52,67 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
 
   DotOpMmaSmemLoader(MMASMEMDescriptor desc, Value baseb128, LinearLayout llInv,
                      ArrayRef<unsigned> instrShape)
-      : desc(desc), baseb128(baseb128), llInv(std::move(llInv)),
+      : desc(desc), baseb128(baseb128), ll(std::move(llInv)),
         instrShape(instrShape) {}
 
   static DotOpMmaSmemLoader
-  build(Location loc, RewriterBase &rewriter, triton::gpu::MemDescType tensor,
+  build(Location loc, RewriterBase &rewriter, gpu::MemDescType memTy,
         Value smemBase, ArrayRef<unsigned> instrShape, int mmaVersion,
+        bool isFp4 = false,
         std::optional<RankedTensorType> mmaTy = std::nullopt,
         std::optional<unsigned> MNdim = std::nullopt) {
-    auto ctx = tensor.getContext();
+    auto ctx = rewriter.getContext();
+    auto kOffset = str_attr("offset");
+    // The handling of subviews is not as fine-grained as it could be.
+    // We could compose with the identity of memTy.getShape() (at the moment
+    // llInv will be of allocShape), but then we would need to handle the
+    // getReps part more carefully. That would let us support more subviews
+    // than we do now; we can implement this generalisation in the future if
+    // needed.
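+    // toLinearLayout(memTy) maps (offset, block) to tensor coordinates;
+    // pseudoinvert() flips it into the coordinates -> (offset, block) map
+    // that the overload below expects.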
+    auto llInv = toLinearLayout(memTy).pseudoinvert();
+    auto bitwidth = memTy.getElementType().getIntOrFloatBitWidth();
+    if (isFp4) {
+      // hacky but well
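+      // The linear layout for the fp4 case addresses the matrix as i8s;
+      // expand it to address individual 4-bit elements along the contiguous
+      // dim so it matches instrShape, which is given in elements of the
+      // original tensor.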
+      auto dims = to_vector(llInv.getInDimNames());
+      auto trans = llInv.getBasis(dims[0], 0, kOffset) == 1;
+      llInv = LinearLayout::identity1D(2, dims[trans ? 0 : 1], kOffset) * llInv;
+      bitwidth /= 2;
+      // The instr_shape comes in number of elements already
+    }
+    return build(loc, rewriter, llInv, bitwidth, smemBase, instrShape,
+                 mmaVersion, mmaTy, MNdim);
+  }
+
+  static DotOpMmaSmemLoader
+  build(Location loc, RewriterBase &rewriter, const LinearLayout &ll,
+        int bitwidth, Value smemBase, ArrayRef<unsigned> instrShapeArray,
+        int mmaVersion, std::optional<RankedTensorType> mmaTy = std::nullopt,
+        std::optional<unsigned> MNdim = std::nullopt) {
+    // ll is a map from two dimensions (dim0, dim1) or (row, col) into offsets
+    // and blocks
+    auto ctx = rewriter.getContext();
+    auto kOffset = str_attr("offset");
+    auto kBlock = str_attr("block");
+    assert(ll.getNumOutDims() == 2);
+    assert(ll.hasOutDim(kOffset) && ll.hasOutDim(kBlock));
+
     assert(mmaVersion == 3 || mmaVersion == 5);
     // Just needed for MMAv3
     assert(mmaTy.has_value() == (mmaVersion == 3));
     assert(MNdim.has_value() == (mmaVersion == 3));
     if (mmaVersion == 3) {
       assert(MNdim.value() < 2);
     }
+    auto instrShape = to_vector(instrShapeArray);
     assert(instrShape.size() == 2);
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    // TODO Assert that calling getShmemAffineBase is valid!
-
-    // Due to the alignment, we can transform ((base + offset) & 0x3FFFF) >> 4
-    // into ((base >> 4) & 0x3FFF + (offset >> 4) where offset is in the inner
-    // loop and ((base >> 4) & 0x3FFF) can be computed once.
-    assert(cast<triton::gpu::SharedEncodingTrait>(tensor.getEncoding())
-               .getAlignment() >= 16);
 
+    // Thanks to the 16B alignment, we can compute the offsets in 128b
+    // elements.
+    // TODO We should assert in the verifier that the alignment is at least 16B
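+    // Concretely, ((base + offset) & 0x3FFFF) >> 4 equals
+    // ((base >> 4) & 0x3FFF) + (offset >> 4) for a 16B-aligned base, so the
+    // masked, shifted base is computed once and the per-tile 128b offsets are
+    // added to it.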
     smemBase = b.ptrtoint(i32_ty, smemBase);
     Value baseSrcb128 = b.lshr(smemBase, b.i32_val(4));
-    int bitwidth = tensor.getElementType().getIntOrFloatBitWidth();
 
-    auto ll = toLinearLayout(tensor);
-    auto kOffset = str_attr("offset");
-    assert(ll.getNumOutDims() == 2);
-    auto dims = to_vector(ll.getOutDimNames());
-    // The linear layout for fp4 represents the matrix as i8s
-    // For it to play ball with instrShape, which is in terms of the original
-    // tensor, we need to represent it as i4s
-    // Interestingly enough, we support i8 x i8 matmul by the looks of it
-    auto isFp4 =
-        tensor.getElementType() == IntegerType::get(ctx, 8) && mmaVersion == 5;
-    auto shape = to_vector(tensor.getShape());
-    if (isFp4) {
-      // hacky but well
-      auto trans = ll.getBasis(kOffset, 0)[0] != 0;
-      ll = LinearLayout::identity1D(2, kOffset, dims[trans ? 0 : 1]) * ll;
-      shape[trans ? 0 : 1] *= 2;
-      bitwidth /= 2;
-      // The instr_shape comes in number of elements already
-    }
-
-    for (auto [dim, instrSize] : llvm::zip(ll.getOutDimNames(), instrShape)) {
-      assert(instrSize <= ll.getOutDimSize(dim) &&
-             "Instr shape is too large for the layout");
-    }
-
-    // TODO Add this to the verifier
-    // We represent fp4 padded tensors as i8s
-    auto desc = getDescriptor(ll, instrShape, bitwidth, mmaVersion);
-
-    // In case it was a subview, we resize it by composing it with the identity
-    // of shape getShape (rather than shape getAllocShape, as toLinearLayout
-    // returns)
-    // Also in the case of mutlicta where the different CTAs have broadcasting
-    // (so no 2-CTA MMA) we effectively need to pseudoinvert. This also achieves
-    // that
-    auto outDims = to_vector(ll.getOutDimNames());
-    auto identity = LinearLayout::identity1D(shape[0], outDims[0], outDims[0]) *
-                    LinearLayout::identity1D(shape[1], outDims[1], outDims[1]);
-    auto llInv = identity.invertAndCompose(ll);
-
-    auto blockInstrShape = to_vector(instrShape);
     if (mmaVersion == 3) {
       auto mndim = MNdim.value();
       auto mmaLl = gpu::toLinearLayout(mmaTy.value());
@@ -133,15 +122,15 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
       auto mmaWarps = mmaLl.sublayout({kWarp}, {outDims[mndim]}) *
                       LinearLayout::identity1D(1, kWarp, outDims[1 - mndim]);
       // Map from warps to offsets in bitwidth elements
-      auto warpToOffset = mmaWarps.compose(llInv);
+      auto warpToOffset = mmaWarps.compose(ll);
       // Map from warps to offsets in 128b elements
       auto maybeWarpToOffsetb128 =
           divideLeft(warpToOffset,
                      LinearLayout::zeros1D(1, kWarp, kOffset, 128 / bitwidth));
       assert(maybeWarpToOffsetb128.has_value());
       // zero out the first two warp bases to have a warpgroup to offset map
-      assert(maybeWarpToOffsetb128->getNumOutDims() == 2);
       auto bases = maybeWarpToOffsetb128->getBases();
+      assert(maybeWarpToOffsetb128->getNumOutDims() == 2);
       bases[kWarp][0] = {0, 0};
       bases[kWarp][1] = {0, 0};
       auto warpGroupToOffsetb128 = LinearLayout(
@@ -152,34 +141,40 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
                                          {{kWarp, warpId}})[0]
                               .second;
       baseSrcb128 = b.add(baseSrcb128, warpStrideb128);
-      // Increase the instruction shape to describe the size at a warp level
-      // A bit hacky but well
+      // Increase the instruction shape to describe the size at a block level,
+      // since the input only describes it at a warp level.
       int logwgAlongMN = 0;
       for (int i = 0; i < warpGroupToOffsetb128.getInDimSizeLog2(kWarp); i++) {
         if (warpGroupToOffsetb128.getBasis(kWarp, i, kOffset) != 0) {
           logwgAlongMN++;
         }
       }
-      blockInstrShape[mndim] *= (1 << logwgAlongMN);
+      instrShape[mndim] *= (1 << logwgAlongMN);
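+      // e.g. with two warpgroups laid out along the MN dim, logwgAlongMN is 1
+      // and the MN instruction size doubles.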
     }
 
+    for (auto [dim, instrSize] : llvm::zip(ll.getInDimNames(), instrShape)) {
+      assert(instrSize <= ll.getInDimSize(dim) &&
+             "Instruction shape is too large for the layout");
+    }
+
+    auto desc = getDescriptor(ll, instrShape, bitwidth, mmaVersion);
+
     Value baseb128 = b.zext(i64_ty, b.and_(baseSrcb128, b.i32_val(0x3FFF)));
-    return DotOpMmaSmemLoader(desc, baseb128, llInv, blockInstrShape);
+    return {desc, baseb128, ll, instrShape};
   }
 
   Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
                  Location loc) const {
     auto *ctx = loc.getContext();
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
-    auto dims = to_vector(llInv.getInDimNames());
-    assert((a + 1) * instrShape[0] <= llInv.getInDimSize(dims[0]));
-    assert((b + 1) * instrShape[1] <= llInv.getInDimSize(dims[1]));
-    assert(to_vector(llInv.getOutDimNames()) ==
+    auto dims = to_vector(ll.getInDimNames());
+    assert((a + 1) * instrShape[0] <= ll.getInDimSize(dims[0]));
+    assert((b + 1) * instrShape[1] <= ll.getInDimSize(dims[1]));
+    assert(to_vector(ll.getOutDimNames()) ==
            llvm::to_vector(
                ArrayRef<StringAttr>{str_attr("offset"), str_attr("block")}));
-    int32_t totalOffElems = llInv
-                                .apply({{dims[0], a * instrShape[0]},
-                                        {dims[1], b * instrShape[1]}})[0]
+    int32_t totalOffElems = ll.apply({{dims[0], a * instrShape[0]},
+                                      {dims[1], b * instrShape[1]}})[0]
                                 .second;
     int32_t smemByteOffsetb8 = totalOffElems * desc.bitwidth / 8;
     auto currDesc = desc.descriptor;
@@ -198,37 +193,22 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     return {smemLoad(a, b, rewriter, loc), std::nullopt};
   }
 
+  MMASMEMDescriptor &getDescriptor() { return desc; }
+
 private:
   MMASMEMDescriptor desc;
   Value baseb128;
-  LinearLayout llInv;
+  LinearLayout ll;
   SmallVector<unsigned> instrShape;
 
   static MMASMEMDescriptor getDescriptor(const LinearLayout &ll,
                                          ArrayRef<unsigned> instrShape,
                                          int bitwidth, int mmaVersion) {
     // ll is a map from allocShape into offsets and blocks
-    auto inv = ll.pseudoinvert();
-    auto dims = to_vector(inv.getInDimNames());
+    auto dims = to_vector(ll.getInDimNames());
     auto ctx = dims[0].getContext();
     auto kOffset = str_attr("offset");
 
-    // Detect tcgen05.mma.cta_group::2 as having two CTAs that are not
-    // broadcasting
-    auto kBlock = str_attr("block");
-    auto twoCTAs = ll.getInDimSize(kBlock) > 1 &&
-                   ll.getBasis(kBlock, 0) != ArrayRef<int32_t>({0, 0});
-    SmallVector<unsigned> instrShapePerCTA = to_vector(instrShape);
-    if (twoCTAs) {
-      // In 2CTA mode we split the tensor into two CTAs
-      assert(ll.getInDimSize(kBlock) == 2);
-      if (ll.getBasis(kBlock, 0, dims[0]) != 0) {
-        instrShapePerCTA[0] /= 2;
-      } else {
-        instrShapePerCTA[1] /= 2;
-      }
-    }
-
     // Any CTALayout, it's not really used within getCoreMatrixLinearLayout
     auto CTALayout = triton::gpu::CTALayoutAttr::getDefault(ctx, 2);
 
@@ -242,11 +222,19 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
                                                      CTALayout);
     auto shmemTile =
         getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false);
-    // We unpack the bitwidth == 8 tile
+    // Rename out dims to match the original layout (in case the dims were
+    // (row, col))
+    auto outDims = to_vector(shmemTile.getOutDims());
+    outDims[0].first = dims[0];
+    outDims[1].first = dims[1];
+    shmemTile = LinearLayout(shmemTile.getBases(), outDims,
+                             /*requireSurjective=*/false);
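+    // The basis vectors are kept as-is; only the out-dim names change so the
+    // tile composes with ll's dims below.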
+    // unpack the fp4 layout
     if (bitwidth == 4) {
       shmemTile =
           LinearLayout::identity1D(2, kOffset, dims[1]) * shmemTile;
     }
+
     // getCoreMatrixLinearLayout gives the k-contiguous tile
     // shmemTile is a layout onto a matrix with shape
     // If swizzling != 0: 8 x (8 * swizzling / bitwidth)
@@ -266,10 +254,12 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     // The PTX docs are wrong in a number of ways:
     // 1) LBO can be specified for !transposed && swizzled != 0
     //    PTX says it's assumed to be 1, but we can in fact use it
-    // 2) LBO / SBO are swapped also for !transposed && swizzled != 0
+    // 2) LBO / SBO are swapped also for !transposed && swizzled == 0
     //    PTX just reports this for the transposed case
-    // Luckily enough the generic logic is much simpler than what's
-    // described in the docs
+    // What's more, the computation we do here is conceptually correct and
+    // agrees with the tensor descriptors for wgmma or tcgen05.mma, but not
+    // for tcgen05.cp! tcgen05.cp follows the PTX docs!
     int lbo = 0, sbo = 0;
     int leadingDim = transposed ? 0 : 1;
     int stridedDim = transposed ? 1 : 0;
@@ -279,13 +269,13 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
       std::swap(leadingDim, stridedDim);
     }
     auto log2RowsTile = shmemTileInv.getInDimSizeLog2(dims[leadingDim]);
-    if (inv.getInDimSizeLog2(dims[leadingDim]) > log2RowsTile) {
-      lbo = inv.getBasis(dims[leadingDim], log2RowsTile, kOffset);
+    if (llvm::Log2_32(instrShape[leadingDim]) > log2RowsTile) {
+      lbo = ll.getBasis(dims[leadingDim], log2RowsTile, kOffset);
     }
 
     auto log2ColsTile = shmemTileInv.getInDimSizeLog2(dims[stridedDim]);
-    if (inv.getInDimSizeLog2(dims[stridedDim]) > log2ColsTile) {
-      sbo = inv.getBasis(dims[stridedDim], log2ColsTile, kOffset);
+    if (llvm::Log2_32(instrShape[stridedDim]) > log2ColsTile) {
+      sbo = ll.getBasis(dims[stridedDim], log2ColsTile, kOffset);
     }
 
     // Pad the tile up to the full instruction shape with the relevant
@@ -294,9 +284,9 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     for (int d : {0, 1}) {
       // 'tile' with the atom tile according to the lbo/sbo rules
       for (int i = 1;
-           i < instrShapePerCTA[d] / shmemTileInv.getInDimSize(dims[d]);
+           i < instrShape[d] / shmemTileInv.getInDimSize(dims[d]);
           i *= 2) {
-        auto stride = inv.getBasis(
+        auto stride = ll.getBasis(
            dims[d], shmemTileInv.getInDimSizeLog2(dims[d]), kOffset);
        bases[dims[d]].push_back({stride * i});
      }
@@ -316,11 +306,11 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     shmemTileInv *=
         LinearLayout::identity1D(1, dims[0], str_attr("block"));
 
-    auto quot = getReps(inv, shmemTileInv);
-    if (quot.has_value()) {
+    auto reps = getReps(ll, shmemTileInv);
+    if (reps.has_value()) {
       SMEMDescriptor desc;
       desc.descriptor = mmaVersion == 5 ? 1ULL << 46 : 0ULL;
-      // The lbo / sbo is defined wrt. the 128 tile
+      // The lbo / sbo is defined wrt. the 128b elements
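+      // (the * bitwidth / 8 converts an element stride to bytes, and >> 4
+      // converts bytes to 16-byte / 128-bit units)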
       desc.leadDimensionBaseOffset = (lbo * bitwidth / 8) >> 4;
       desc.strideDimensionBaseOffset = (sbo * bitwidth / 8) >> 4;
       switch (swizzling) {
@@ -342,7 +332,6 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     return {.descriptor = desc,
             .swizzlingByteWidth = swizzling,
             .bitwidth = bitwidth,
-            .twoCTAs = twoCTAs,
             .transposed = transposed,
             .fp4Padded = fp4Padded};
   }