@@ -110,6 +110,167 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110
110
return success ();
111
111
}
112
112
113
  /// Lowers a within-block layout conversion through swizzled shared memory.
  ///
  /// Recursively normalizes the problem (pointer elements, sub-byte elements,
  /// broadcasted registers) before performing the actual shared-memory
  /// round-trip: each normalization does a preprocessing step, recurses, and
  /// undoes itself on the results.
  ///
  /// \param srcLayout / dstLayout register layouts of the source/destination
  ///        (kBlock already stripped by the caller).
  /// \param inVals    unpacked per-register LLVM values in \p srcLayout order.
  /// \param llvmElemTy LLVM element type of the values.
  /// \param smemBase  base pointer of the shared-memory scratch buffer.
  /// \return the converted values, in \p dstLayout register order.
  SmallVector<Value> transferWithinBlockSwizzlingImpl(
      Location loc, ConversionPatternRewriter &rewriter,
      const LinearLayout &srcLayout, const LinearLayout &dstLayout,
      ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
    auto *ctx = rewriter.getContext();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    // We handle transformations recursively as they all need a preprocessing
    // and a postprocessing step.

    // Handle pointer types as 64-bit integers: ptrtoint before the recursive
    // transfer, inttoptr on the way back out.
    if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
      auto llvmElemTyPtr = i64_ty;
      auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
        return b.ptrtoint(llvmElemTyPtr, v).getResult();
      }));
      auto outVals =
          transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
                                           newInVals, llvmElemTyPtr, smemBase);
      for (auto &v : outVals) {
        v = b.inttoptr(llvmElemTy, v);
      }
      return outVals;
    }

    // Handle sub-byte elements like i1: zext to i8 for the transfer, trunc
    // back afterwards.
    if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
      // Upcast to i8
      auto i8ElemTy = i8_ty;
      auto newInVals = llvm::to_vector(llvm::map_range(
          inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
      auto outVals = transferWithinBlockSwizzlingImpl(
          loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
      for (auto &v : outVals) {
        v = b.trunc(llvmElemTy, v);
      }
      return outVals;
    }

    // Remove broadcasting in src: drop duplicated registers so each value is
    // stored once.
    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
    if (!removeBroadcastSrc.isIdentity()) {
      auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
      auto newInVals = removeBroadcastSrc.apply(inVals);
      return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
                                              newInVals, llvmElemTy, smemBase);
    }

    // Remove broadcasting in dst: convert to the de-broadcasted layout, then
    // re-duplicate the results to match the requested dst layout.
    auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
    if (!removeBroadcastDst.isIdentity()) {
      auto prmtDst = removeBroadcastDst.apply(dstLayout);
      auto outVals = transferWithinBlockSwizzlingImpl(
          loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
      return broadcastAs(outVals, dstLayout);
    }

    // At this point we have a type that's at least 8-bit
    // and we don't have broadcasting in the registers
    auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
    auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);

    // Extract reps from smem: "reps" counts how many barrier-separated
    // store/load rounds are needed to fit the conversion in the scratch tile.
    auto kReg = str_attr("register");
    auto kReps = str_attr("reps");
    auto nReps = smem.getInDimSize(kReps);
    auto reps = LinearLayout::identity1D(nReps, kReg, kReps);

    // Maps from registers to shared-memory coordinates for the store (src)
    // and load (dst) sides.
    auto totalStoreCvt = srcLayout.invertAndCompose(smem);
    auto totalLoadCvt = dstLayout.invertAndCompose(smem);

    // The permutation exists by construction of the reps dimension in
    // optimalSwizzling
    auto permStore =
        regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
    totalStoreCvt = permStore.apply(totalStoreCvt);
    auto permutedInVals = permStore.apply(inVals);
    auto permLoad =
        regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
    totalLoadCvt = permLoad.apply(totalLoadCvt);

    // Remove the reps and flatten into offset
    auto storeCvt = *divideRight(totalStoreCvt, reps);
    auto loadCvt = *divideRight(totalLoadCvt, reps);
    auto kOffset = str_attr("offset");
    storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
    loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});

    // Number of registers handled per rep on the store side.
    auto tileSize = storeCvt.getInDimSize(kReg);

    assert(permutedInVals.size() == tileSize * nReps);
    SmallVector<Value> outVals;
    // No padding/affine offset is applied: the scratch buffer starts at
    // smemBase and the swizzling handles bank conflicts.
    auto noPaddingOffset = [](Value v) { return v; };
    auto affineOffset = b.i32_val(0);
    auto maskSpanAffineOffset = 0;
    for (int i = 0; i < nReps; ++i) {
      // Barrier between reps so the next store does not overwrite data some
      // threads have not loaded yet (the first rep needs none).
      if (i > 0)
        b.barrier();

      auto tileInVals =
          ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
      // Store
      lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
                      noPaddingOffset, affineOffset, maskSpanAffineOffset,
                      rewriter, targetInfo);
      // Ensure all stores are visible before any thread loads.
      b.barrier();
      // Load
      SmallVector<Value> tileOutVals = lowerLdStShared(
          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
          affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
      llvm::append_range(outVals, tileOutVals);
    }

    // Undo the permLoad used to divideRight
    outVals = permLoad.inverse().apply(outVals);
    return outVals;
  }
229
+
230
+ LogicalResult
231
+ transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
232
+ ConversionPatternRewriter &rewriter) const {
233
+ // Fallback for now to standard lowering if it can use stmatrix
234
+ auto scratchConfig =
235
+ getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
236
+ bool isStMatrix = targetInfo.canUseStMatrix (
237
+ op.getSrc ().getType (), scratchConfig.repShape ,
238
+ scratchConfig.paddedRepShape , scratchConfig.order ,
239
+ /* swizzleByteSize=*/ 0 );
240
+ if (isStMatrix) {
241
+ return failure ();
242
+ }
243
+
244
+ auto loc = op.getLoc ();
245
+ auto *ctx = op.getContext ();
246
+ auto srcTy = op.getSrc ().getType ();
247
+ auto dstTy = op.getType ();
248
+
249
+ // Remove the kBlock dimension from the layout as it's the identity in the
250
+ // cvt
251
+ auto srcLayout = toLinearLayout (srcTy);
252
+ auto dstLayout = toLinearLayout (dstTy);
253
+ auto kReg = str_attr (" register" );
254
+ auto kLane = str_attr (" lane" );
255
+ auto kWarp = str_attr (" warp" );
256
+ srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
257
+ to_vector (srcLayout.getOutDimNames ()));
258
+ dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
259
+ to_vector (dstLayout.getOutDimNames ()));
260
+
261
+ auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
262
+ auto smemBase =
263
+ LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
264
+ auto inVals = unpackLLElements (loc, src, rewriter);
265
+ auto outVals = transferWithinBlockSwizzlingImpl (
266
+ loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
267
+
268
+ Value result =
269
+ packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
270
+ rewriter.replaceOp (op, result);
271
+ return success ();
272
+ }
273
+
113
274
LogicalResult transferWithinBlock (ConvertLayoutOp op,
114
275
const LinearLayout &srcLayout,
115
276
const LinearLayout &dstLayout,
@@ -118,8 +279,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
118
279
assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
119
280
120
281
// Try to use swizzling to implement the conversion
121
- if (succeeded (transferWithinBlockSwizzling (op, adaptor.getSrc (), targetInfo,
122
- getTypeConverter (), rewriter))) {
282
+ // HACK Remove once AMD tests pass for the swizzling path
283
+ if (targetInfo.isCuda () && succeeded (transferWithinBlockSwizzling (
284
+ op, adaptor.getSrc (), rewriter))) {
123
285
return success ();
124
286
}
125
287
0 commit comments