@@ -282,111 +282,79 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
282282 const auto &shape = op.getType ().getShape ();
283283 auto srcTy = op.getSrc ().getType ();
284284 auto dstTy = op.getType ();
285- std::optional<LinearLayout> srcLayout =
286- toLinearLayout (shape, srcTy.getEncoding ());
287- std::optional<LinearLayout> dstLayout =
288- toLinearLayout (shape, dstTy.getEncoding ());
289- if (!srcLayout.has_value () || !dstLayout.has_value ()) {
290- return failure ();
291- }
292285
293- // There are four cases to handle.
294- //
295- // 1. Transfer between values in the same thread, in which case we simply
296- // reorder the elements of adaptor.getSrc().
297- // 2. Transfer between values in the same warp, in which case we try to
298- // move values using warp shuffles, though if the pattern is complicated
299- // enough we may fall back to using shared memory (case 3).
300- // 3. Transfer between values in the same CTA, in which case we move values
301- // through shared memory.
302- // 4. Transfer between values in different CTAs, in which case we move
303- // values through distributed shared memory.
304- //
305- // We can tell which case we're in by examining `conversion`.
306- // For example, if the block -> block mapping is an identity layout: {1, 2,
307- // 4, ...}, then there's no movement between data in different CTAs, and we
308- // know we're not in case 4.
309- if (cvtReordersRegisters (srcTy, dstTy)) { // Case 1.
310- return transferWithinThread (op, *srcLayout, *dstLayout, adaptor,
311- rewriter);
286+ auto conversion = minimalCvtLayout (srcTy, dstTy);
287+ if (!conversion.has_value ()) {
288+ return rewriter.notifyMatchFailure (
289+ op, " NYI. srcTy and/or dstTy don't implement LLs yet" );
312290 }
313291
314- if (cvtNeedsWarpShuffle (srcTy, dstTy)) { // Case 2.
315- return transferWithinLane (op, *srcLayout, *dstLayout, adaptor, rewriter);
292+ assert (to_vector (conversion->getInDimNames ()) ==
293+ to_vector (conversion->getOutDimNames ()));
294+ auto dims = conversion->getInDimNames ();
295+ if (llvm::is_contained (dims, str_attr (" block" ))) {
296+ // Case 1: Transfer between values in different CTAs.
297+ // This requires moving values through distributed shared memory.
298+ return rewriter.notifyMatchFailure (
299+ op, " NYI: Transfer between different CTAs" );
300+ } else if (llvm::is_contained (dims, str_attr (" warp" ))) {
301+ // Case 2: Transfer between values in the same CTA, in which case we move
302+ // values through shared memory.
303+ LinearLayout srcLayout =
304+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
305+ LinearLayout dstLayout =
306+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
307+ return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
308+ } else if (llvm::is_contained (dims, str_attr (" lane" ))) {
309+ // Case 3: Transfer between values in the same warp, in which case we try
310+ // to move values using warp shuffles, though if the pattern is
311+ // complicated enough we may fall back to using shared memory.
312+ // TODO(Keren): implement warp shuffle instead of using the general
313+ // approach that uses shared memory.
314+ LinearLayout srcLayout =
315+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
316+ LinearLayout dstLayout =
317+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
318+ return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
319+ } else if (llvm::is_contained (dims, str_attr (" register" ))) {
320+ // Case 4: Transfer between values in the same thread, in which case we
321+ // simply reorder the elements of adaptor.getSrc().
322+ return transferWithinThread (op, *conversion, adaptor, rewriter);
323+ } else {
324+ // The two layouts are equivalent. We should probably remove these in
325+ // RemoveLayoutConversion.
326+ rewriter.replaceOp (op, adaptor.getSrc ());
327+ return success ();
316328 }
317-
318- return transferWithinBlockOrGroup (op, *srcLayout, *dstLayout, adaptor,
319- rewriter); // Case 3 and 4
320329 }
321330
// [diff hunk] transferWithinThread: lines marked `-` are the removed
// implementation, lines marked `+` are its replacement, and unmarked lines are
// context shared by both versions.
//
// Purpose: lowers a ConvertLayoutOp by choosing, for each destination
// register, one of the unpacked source elements, then repacking the result.
//
// Parameters (new version):
//   op         - the ConvertLayoutOp being replaced.
//   conversion - layout applied to a destination `register` index to obtain
//                the index into the unpacked source values. The removed
//                version instead took srcLayout/dstLayout and derived this
//                mapping itself via invertAndCompose + divideRight.
//   adaptor    - supplies the converted source value (adaptor.getSrc()).
//   rewriter   - used to unpack/pack elements and to replace `op`.
// Returns success() after replacing `op` with the repacked value.
322331 LogicalResult
323- transferWithinThread (ConvertLayoutOp op, const LinearLayout &srcLayout ,
324- const LinearLayout &dstLayout, OpAdaptor adaptor,
332+ transferWithinThread (ConvertLayoutOp op, const LinearLayout &conversion ,
333+ OpAdaptor adaptor,
325334 ConversionPatternRewriter &rewriter) const {
326335 MLIRContext *ctx = op.getContext ();
327336 auto loc = op.getLoc ();
328337 StringAttr kRegister = str_attr (" register" );
329- StringAttr kLane = str_attr (" lane" );
330- StringAttr kWarp = str_attr (" warp" );
331- StringAttr kBlock = str_attr (" block" );
332-
333- // There are three possible cases:
334- //
335- // 1. `srcLayout` has the same number of registers as `dstLayout`.
336- // 2. `srcLayout` has fewer registers than `dstLayout`.
337- // 3. `srcLayout` has more registers than `dstLayout`.
338- //
339- // In the second case `srcLayout . dstLayout^-1` is not surjective
340- // because not all destination registers are covered.
341- // Since the goal is to cover all of the destination
342- // registers, we can instead use `dstLayout . srcLayout^-1`.
343- LinearLayout conversion = dstLayout.invertAndCompose (srcLayout);
344- auto dstToSrc = conversion.divideRight (
345- LinearLayout::identity1D (conversion.getInDimSize (kLane ), kLane , kLane ) *
346- LinearLayout::identity1D (conversion.getInDimSize (kWarp ), kWarp , kWarp ) *
347- LinearLayout::identity1D (conversion.getInDimSize (kBlock ), kBlock ,
348- kBlock ));
349-
350338 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
351- assert (ArrayRef (to_vector (dstToSrc->getInDimNames ())) ==
352- ArrayRef{kRegister });
353- assert (ArrayRef (to_vector (dstToSrc->getOutDimNames ())) ==
354- ArrayRef{kRegister });
355339
// For each destination register i, look up the source element index and copy
// that value into slot i. Only the first output dim of the applied layout is
// read. NOTE(review): the removed code asserted that the in/out dims were
// exactly {register}; the new code has no such assert here, so confirm that
// callers only pass a register-only `conversion`.
356340 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
357341 SmallVector<Value> outVals;
358- outVals.resize (dstToSrc-> getInDimSize (kRegister ));
359- for (int i = 0 ; i < dstToSrc-> getInDimSize (kRegister ); i++) {
360- auto srcIdx = dstToSrc-> apply ({{kRegister , i}});
361- outVals[i] = inVals[srcIdx. begin ()-> second ];
342+ outVals.resize (conversion. getInDimSize (kRegister ));
343+ for (int i = 0 ; i < conversion. getInDimSize (kRegister ); i++) {
344+ auto srcIdx = conversion. apply ({{kRegister , i}}). begin ()-> second ;
345+ outVals[i] = inVals[srcIdx];
362346 }
363347 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
364348 op.getType ());
365349 rewriter.replaceOp (op, result);
366350 return success ();
367351 }
368352
369- LogicalResult transferWithinLane (ConvertLayoutOp op,
370- const LinearLayout &srcLayout,
371- const LinearLayout &dstLayout,
372- OpAdaptor adaptor,
373- ConversionPatternRewriter &rewriter) const {
374- // TODO(Keren): implement warp shuffle instead of using the general approach
375- // that uses shared memory
376- return transferWithinBlockOrGroup (op, srcLayout, dstLayout, adaptor,
377- rewriter);
378- }
379-
380- LogicalResult
381- transferWithinBlockOrGroup (ConvertLayoutOp op, const LinearLayout &srcLayout,
382- const LinearLayout &dstLayout, OpAdaptor adaptor,
383- ConversionPatternRewriter &rewriter) const {
384- LinearLayout conversion = srcLayout.invertAndCompose (dstLayout);
385-
386- // TODO(Keren): LLs support cross-CTA conversions, this function does not
387- if (isCrossCTAConversion (conversion))
388- return failure ();
389-
353+ LogicalResult transferWithinBlock (ConvertLayoutOp op,
354+ const LinearLayout &srcLayout,
355+ const LinearLayout &dstLayout,
356+ OpAdaptor adaptor,
357+ ConversionPatternRewriter &rewriter) const {
390358 MLIRContext *ctx = op.getContext ();
391359 auto loc = op.getLoc ();
392360 auto srcTy = op.getSrc ().getType ();
@@ -445,11 +413,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
445413 }
446414 }
447415
416+ // NOTE: getLayoutWithinBlock is most likely the identity function at the
417+ // moment. It would be better to simply call `quotient({kBlock})` and
418+ // remove kBlock from transferWithinBlockImpl.
448419 auto srcLayoutWithinBlock = getLayoutWithinBlock (srcLayout);
449420 auto dstLayoutWithinBlock = getLayoutWithinBlock (dstLayout);
450421 SmallVector<Value> outVals =
451- transferWithinBlock (inVals, op, srcLayoutWithinBlock,
452- dstLayoutWithinBlock, adaptor, rewriter);
422+ transferWithinBlockImpl (inVals, op, srcLayoutWithinBlock,
423+ dstLayoutWithinBlock, adaptor, rewriter);
453424
454425 // Unmunge output values
455426 for (const auto &it : llvm::enumerate (outVals)) {
@@ -467,10 +438,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
467438 }
468439
469440 SmallVector<Value>
470- transferWithinBlock (ArrayRef<Value> inVals, ConvertLayoutOp op,
471- const LinearLayout &srcLayout,
472- const LinearLayout &dstLayout, OpAdaptor adaptor,
473- ConversionPatternRewriter &rewriter) const {
441+ transferWithinBlockImpl (ArrayRef<Value> inVals, ConvertLayoutOp op,
442+ const LinearLayout &srcLayout,
443+ const LinearLayout &dstLayout, OpAdaptor adaptor,
444+ ConversionPatternRewriter &rewriter) const {
474445 MLIRContext *ctx = op.getContext ();
475446 auto loc = op.getLoc ();
476447
0 commit comments