@@ -250,10 +250,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
250
250
MLIRContext *ctx = op.getContext ();
251
251
252
252
const auto &shape = op.getType ().getShape ();
253
+ auto srcTy = op.getSrc ().getType ();
254
+ auto dstTy = op.getType ();
253
255
std::optional<LinearLayout> srcLayout =
254
- toLinearLayout (shape, op. getSrc (). getType () .getEncoding ());
256
+ toLinearLayout (shape, srcTy .getEncoding ());
255
257
std::optional<LinearLayout> dstLayout =
256
- toLinearLayout (shape, op. getType () .getEncoding ());
258
+ toLinearLayout (shape, dstTy .getEncoding ());
257
259
if (!srcLayout.has_value () || !dstLayout.has_value ()) {
258
260
return failure ();
259
261
}
@@ -270,93 +272,94 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
270
272
// 4. Transfer between values in different CTAs, in which case we move
271
273
// values through distributed shared memory.
272
274
//
273
- // We can tell which case we're in by examining `conversion`. If e.g. the
274
- // block -> block mapping is {1, 2, 4, ...} then there's no movement between
275
- // data in different CTAs and we know we're not in case 4.
276
- LinearLayout conversion = srcLayout->invertAndCompose (*dstLayout);
277
-
278
- int numLanes = conversion.getInDimSize (str_attr (" lane" ));
279
- int numWarps = conversion.getInDimSize (str_attr (" warp" ));
280
- int numBlocks = conversion.getInDimSize (str_attr (" block" ));
281
-
282
- StringAttr kLane = str_attr (" lane" );
283
- StringAttr kWarp = str_attr (" warp" );
284
- StringAttr kBlock = str_attr (" block" );
285
-
286
- // TODO(jlebar): These checks are overly-restrictive. For example, we can
287
- // transfer by shuffling registers (case 1) if and only if all of the bases
288
- // for `register` have 0s for lane, warp, and block. But the check below is
289
- // stronger than this, checking also that the choice of lane/warp/block does
290
- // not affect the permutation of registers. If we allow different
291
- // lane/warp/blocks to have different permutations, we can generalize this.
292
- if (std::optional<LinearLayout> c = conversion.divideRight (
293
- LinearLayout::identity1D (numLanes, kLane , kLane ) *
294
- LinearLayout::identity1D (numWarps, kWarp , kWarp ) *
295
- LinearLayout::identity1D (numBlocks, kBlock , kBlock ));
296
- c.has_value ()) {
297
- return transferWithinThread (*c, op, adaptor, rewriter);
275
+ // We can tell which case we're in by examining `conversion`.
276
+ // For example, if the block -> block mapping is an identity layout: {1, 2,
277
+ // 4, ...}, then there's no movement between data in different CTAs, and we
278
+ // know we're not in case 4.
279
+ if (cvtReordersRegisters (srcTy, dstTy)) { // Case 1.
280
+ return transferWithinThread (op, *srcLayout, *dstLayout, adaptor,
281
+ rewriter);
298
282
}
299
283
300
- if (std::optional<LinearLayout> c = conversion.divideRight (
301
- LinearLayout::identity1D (numWarps, kWarp , kWarp ) *
302
- LinearLayout::identity1D (numBlocks, kBlock , kBlock ));
303
- c.has_value ()) {
304
- return transferWithinLane (*c, op, adaptor, rewriter);
284
+ if (cvtNeedsWarpShuffle (srcTy, dstTy)) { // Case 2.
285
+ return transferWithinLane (op, *srcLayout, *dstLayout, adaptor, rewriter);
305
286
}
306
287
307
- return transferWithinBlockOrGroup (conversion, op, *srcLayout, *dstLayout,
308
- adaptor, rewriter);
288
+ return transferWithinBlockOrGroup (op, *srcLayout, *dstLayout, adaptor ,
289
+ rewriter); // Case 3 and 4
309
290
}
310
291
311
292
LogicalResult
312
- transferWithinThread (const LinearLayout &conversion, ConvertLayoutOp op ,
313
- OpAdaptor adaptor,
293
+ transferWithinThread (ConvertLayoutOp op, const LinearLayout &srcLayout ,
294
+ const LinearLayout &dstLayout, OpAdaptor adaptor,
314
295
ConversionPatternRewriter &rewriter) const {
315
296
MLIRContext *ctx = op.getContext ();
316
297
auto loc = op.getLoc ();
317
298
StringAttr kRegister = str_attr (" register" );
299
+ StringAttr kLane = str_attr (" lane" );
300
+ StringAttr kWarp = str_attr (" warp" );
301
+ StringAttr kBlock = str_attr (" block" );
302
+
303
+ // There are three possible cases:
304
+ //
305
+ // 1. `srcLayout` has the same number of registers as `dstLayout`.
306
+ // 2. `srcLayout` has fewer registers than `dstLayout`.
307
+ // 3. `srcLayout` has more registers than `dstLayout`.
308
+ //
309
+ // In the second case `srcLayout . dstLayout^-1` is not surjective
310
+ // because not all destination registers are covered.
311
+ // Since the goal is to cover all of the destination
312
+ // registers, we can instead use `dstLayout . srcLayout^-1`.
313
+ LinearLayout conversion = dstLayout.invertAndCompose (srcLayout);
314
+ auto dstToSrc = conversion.divideRight (
315
+ LinearLayout::identity1D (conversion.getInDimSize (kLane ), kLane , kLane ) *
316
+ LinearLayout::identity1D (conversion.getInDimSize (kWarp ), kWarp , kWarp ) *
317
+ LinearLayout::identity1D (conversion.getInDimSize (kBlock ), kBlock ,
318
+ kBlock ));
318
319
319
320
assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
320
- assert (ArrayRef (to_vector (conversion. getInDimNames ())) ==
321
+ assert (ArrayRef (to_vector (dstToSrc-> getInDimNames ())) ==
321
322
ArrayRef{kRegister });
322
- assert (ArrayRef (to_vector (conversion. getOutDimNames ())) ==
323
+ assert (ArrayRef (to_vector (dstToSrc-> getOutDimNames ())) ==
323
324
ArrayRef{kRegister });
324
325
325
326
auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
326
- SmallVector<Value> outVals (conversion.getOutDimSize (kRegister ));
327
- for (int i = 0 ; i < conversion.getInDimSize (kRegister ); i++) {
328
- auto dstIdx = conversion.apply ({{kRegister , i}});
329
- outVals[dstIdx.begin ()->second ] = inVals[i];
327
+ SmallVector<Value> outVals;
328
+ outVals.resize (dstToSrc->getInDimSize (kRegister ));
329
+ for (int i = 0 ; i < dstToSrc->getInDimSize (kRegister ); i++) {
330
+ auto srcIdx = dstToSrc->apply ({{kRegister , i}});
331
+ outVals[i] = inVals[srcIdx.begin ()->second ];
330
332
}
331
333
Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
332
334
op.getType ());
333
335
rewriter.replaceOp (op, result);
334
336
return success ();
335
337
}
336
338
337
- LogicalResult transferWithinLane (const LinearLayout &conversion,
338
- ConvertLayoutOp op, OpAdaptor adaptor,
339
+ LogicalResult transferWithinLane (ConvertLayoutOp op,
340
+ const LinearLayout &srcLayout,
341
+ const LinearLayout &dstLayout,
342
+ OpAdaptor adaptor,
339
343
ConversionPatternRewriter &rewriter) const {
340
344
// TODO(jlebar): Implement me.
341
345
return failure ();
342
346
}
343
347
344
348
LogicalResult
345
- transferWithinBlockOrGroup (const LinearLayout &conversion, ConvertLayoutOp op,
346
- const LinearLayout &srcLayout,
349
+ transferWithinBlockOrGroup (ConvertLayoutOp op, const LinearLayout &srcLayout,
347
350
const LinearLayout &dstLayout, OpAdaptor adaptor,
348
351
ConversionPatternRewriter &rewriter) const {
352
+ LinearLayout conversion = srcLayout.invertAndCompose (dstLayout);
353
+
349
354
// TODO(Keren): LLs support cross-CTA conversions, this function does not
350
355
if (isCrossCTAConversion (conversion))
351
356
return failure ();
352
357
353
358
MLIRContext *ctx = op.getContext ();
354
359
auto loc = op.getLoc ();
355
360
356
- assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
357
-
358
- // TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
359
- // conversions. Once we have ldmatrix support in
361
+ // TODO(jlebar): For now we handle only blocked/slice ->
362
+ // blocked/slice conversions. Once we have ldmatrix support in
360
363
// load/storeDistributedToShared, we can remove this constraint.
361
364
std::function<bool (Attribute)> layoutIsOK = [&](Attribute layout) {
362
365
if (isa<BlockedEncodingAttr>(layout)) {
@@ -372,6 +375,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
372
375
return failure ();
373
376
}
374
377
378
+ assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
379
+
375
380
SmallVector<Value> inVals =
376
381
unpackLLElements (loc, adaptor.getSrc (), rewriter);
377
382
assert (!inVals.empty ());
0 commit comments