@@ -110,167 +110,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110
110
return success ();
111
111
}
112
112
113
- SmallVector<Value> transferWithinBlockSwizzlingImpl (
114
- Location loc, ConversionPatternRewriter &rewriter,
115
- const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116
- ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
117
- auto *ctx = rewriter.getContext ();
118
- auto b = TritonLLVMOpBuilder (loc, rewriter);
119
- // We handle transformations recursively as they all need a preprocessing
120
- // and a postprocessing step.
121
-
122
- // Handle pointer types as 64-bit integers
123
- if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
124
- auto llvmElemTyPtr = i64_ty;
125
- auto newInVals = llvm::to_vector (llvm::map_range (inVals, [&](Value v) {
126
- return b.ptrtoint (llvmElemTyPtr, v).getResult ();
127
- }));
128
- auto outVals =
129
- transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
130
- newInVals, llvmElemTyPtr, smemBase);
131
- for (auto &v : outVals) {
132
- v = b.inttoptr (llvmElemTy, v);
133
- }
134
- return outVals;
135
- }
136
-
137
- // Handle sub-byte elements like i1
138
- if (llvmElemTy.getIntOrFloatBitWidth () < 8 ) {
139
- // Upcast to i8
140
- auto i8ElemTy = i8_ty;
141
- auto newInVals = llvm::to_vector (llvm::map_range (
142
- inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
143
- auto outVals = transferWithinBlockSwizzlingImpl (
144
- loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145
- for (auto &v : outVals) {
146
- v = b.trunc (llvmElemTy, v);
147
- }
148
- return outVals;
149
- }
150
-
151
- // Remove broadcasting in src
152
- auto removeBroadcastSrc = actionRemoveBroadcastedRegs (srcLayout);
153
- if (!removeBroadcastSrc.isIdentity ()) {
154
- auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
155
- auto newInVals = removeBroadcastSrc.apply (inVals);
156
- return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
157
- newInVals, llvmElemTy, smemBase);
158
- }
159
-
160
- // Remove broadcasting in dst
161
- auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
162
- if (!removeBroadcastDst.isIdentity ()) {
163
- auto prmtDst = removeBroadcastDst.apply (dstLayout);
164
- auto outVals = transferWithinBlockSwizzlingImpl (
165
- loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166
- return broadcastAs (outVals, dstLayout);
167
- }
168
-
169
- // At this point we have a type that's at least 8-bit
170
- // and we don't have broadcasting in the registers
171
- auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
172
- auto smem = optimalSwizzling (srcLayout, dstLayout, bitwidth);
173
-
174
- // Extract reps from smem
175
- auto kReg = str_attr (" register" );
176
- auto kReps = str_attr (" reps" );
177
- auto nReps = smem.getInDimSize (kReps );
178
- auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
179
-
180
- auto totalStoreCvt = srcLayout.invertAndCompose (smem);
181
- auto totalLoadCvt = dstLayout.invertAndCompose (smem);
182
-
183
- // The permutation exists by construction of the reps dimension in
184
- // optimalSwizzling
185
- auto permStore =
186
- regPermForDivide (totalStoreCvt, reps, /* left=*/ false ).value ();
187
- totalStoreCvt = permStore.apply (totalStoreCvt);
188
- auto permutedInVals = permStore.apply (inVals);
189
- auto permLoad =
190
- regPermForDivide (totalLoadCvt, reps, /* left=*/ false ).value ();
191
- totalLoadCvt = permLoad.apply (totalLoadCvt);
192
-
193
- // Remove the reps and flatten into offset
194
- auto storeCvt = *divideRight (totalStoreCvt, reps);
195
- auto loadCvt = *divideRight (totalLoadCvt, reps);
196
- auto kOffset = str_attr (" offset" );
197
- storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
198
- loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
199
-
200
- auto tileSize = storeCvt.getInDimSize (kReg );
201
-
202
- assert (permutedInVals.size () == tileSize * nReps);
203
- SmallVector<Value> outVals;
204
- auto noPaddingOffset = [](Value v) { return v; };
205
- auto affineOffset = b.i32_val (0 );
206
- auto maskSpanAffineOffset = 0 ;
207
- for (int i = 0 ; i < nReps; ++i) {
208
- if (i > 0 )
209
- b.barrier ();
210
-
211
- auto tileInVals =
212
- ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
213
- // Store
214
- lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
215
- noPaddingOffset, affineOffset, maskSpanAffineOffset,
216
- rewriter, targetInfo);
217
- b.barrier ();
218
- // Load
219
- SmallVector<Value> tileOutVals = lowerLdStShared (
220
- loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
221
- affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
222
- llvm::append_range (outVals, tileOutVals);
223
- }
224
-
225
- // Undo the permLoad used to divideRight
226
- outVals = permLoad.inverse ().apply (outVals);
227
- return outVals;
228
- }
229
-
230
- LogicalResult
231
- transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
232
- ConversionPatternRewriter &rewriter) const {
233
- // Fallback for now to standard lowering if it can use stmatrix
234
- auto scratchConfig =
235
- getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
236
- bool isStMatrix = targetInfo.canUseStMatrix (
237
- op.getSrc ().getType (), scratchConfig.repShape ,
238
- scratchConfig.paddedRepShape , scratchConfig.order ,
239
- /* swizzleByteSize=*/ 0 );
240
- if (isStMatrix) {
241
- return failure ();
242
- }
243
-
244
- auto loc = op.getLoc ();
245
- auto *ctx = op.getContext ();
246
- auto srcTy = op.getSrc ().getType ();
247
- auto dstTy = op.getType ();
248
-
249
- // Remove the kBlock dimension from the layout as it's the identity in the
250
- // cvt
251
- auto srcLayout = toLinearLayout (srcTy);
252
- auto dstLayout = toLinearLayout (dstTy);
253
- auto kReg = str_attr (" register" );
254
- auto kLane = str_attr (" lane" );
255
- auto kWarp = str_attr (" warp" );
256
- srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
257
- to_vector (srcLayout.getOutDimNames ()));
258
- dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
259
- to_vector (dstLayout.getOutDimNames ()));
260
-
261
- auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
262
- auto smemBase =
263
- LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
264
- auto inVals = unpackLLElements (loc, src, rewriter);
265
- auto outVals = transferWithinBlockSwizzlingImpl (
266
- loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
267
-
268
- Value result =
269
- packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
270
- rewriter.replaceOp (op, result);
271
- return success ();
272
- }
273
-
274
113
LogicalResult transferWithinBlock (ConvertLayoutOp op,
275
114
const LinearLayout &srcLayout,
276
115
const LinearLayout &dstLayout,
@@ -279,9 +118,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
279
118
assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
280
119
281
120
// Try to use swizzling to implement the conversion
282
- // HACK Remove once AMD tests pass for the swizzling path
283
- if (targetInfo.isCuda () && succeeded (transferWithinBlockSwizzling (
284
- op, adaptor.getSrc (), rewriter))) {
121
+ if (succeeded (transferWithinBlockSwizzling (op, adaptor.getSrc (), targetInfo,
122
+ getTypeConverter (), rewriter))) {
285
123
return success ();
286
124
}
287
125
0 commit comments