@@ -419,13 +419,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
419419 : gatherOp(gatherOp) {}
420420
421421unsigned GatherLoweringHelper::getScratchSizeInBytes () {
422- // For now, lower the gather op by writing the source tensor to shared memory.
423- // TODO(jeff): Leverage locality to avoid using scratch space when possible.
422+ // If the gather is warp-local, no scratch space is needed.
423+ if (isWarpLocal ())
424+ return 0 ;
425+
426+ // Otherwise, performing the gather will require scratch space to communicate
427+ // the source tensor across threads. For now, assume the whole source tensor
428+ // is written back to shared memory.
424429 RankedTensorType srcType = gatherOp.getSrc ().getType ();
425430 return product (srcType.getShape ()) *
426431 ceil<unsigned >(srcType.getElementTypeBitWidth (), 8 );
427432}
428433
434+ bool GatherLoweringHelper::isWarpLocal () {
435+ // The gather is warp-local if for each column along the gather axis in the
436+ // source and index tensors, all the elements are owned by the same warp.
437+ RankedTensorType srcType = gatherOp.getSrc ().getType ();
438+ RankedTensorType idxType = gatherOp.getIndices ().getType ();
439+ std::optional<LinearLayout> srcLayout =
440+ toLinearLayout (srcType.getShape (), srcType.getEncoding ());
441+ std::optional<LinearLayout> idxLayout =
442+ toLinearLayout (idxType.getShape (), idxType.getEncoding ());
443+
444+ // FIXME: If an unsupported layout was encountered, assume the gather is not
445+ // warp-local.
446+ if (!srcLayout || !idxLayout)
447+ return false ;
448+
449+ Builder b (gatherOp.getContext ());
450+ StringAttr kBlock = b.getStringAttr (" block" );
451+ StringAttr kWarp = b.getStringAttr (" warp" );
452+ StringAttr kLane = b.getStringAttr (" lane" );
453+ StringAttr kGatherDim =
454+ b.getStringAttr (" dim" + std::to_string (gatherOp.getAxis ()));
455+
456+ // The tensor layouts must be distributed layouts, where the basis matrix is a
457+ // subpermutation matrix (permutation matrix plus zeros for broadcasting).
458+ // FIXME(jeff): Check this invariant somehow.
459+ //
460+ // We want to know if all elements of a column along the gather axis are
461+ // mapped to the same set of warps, which means the gather can be performed
462+ // entirely within the warp. We need to query
463+ //
464+ // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp})
465+ //
466+ // But due to broadcasting, the matrix might not be invertible. But since the
467+ // matrix is a permutation matrix (checked below), we can instead query
468+ //
469+ // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim})
470+ //
471+ // Which implies that changing the warp will not change the gather dimension.
472+ // And since there is no swizzling, this applies to all warps.
473+ if (!srcLayout->sublayoutIsZero ({kBlock , kWarp }, kGatherDim ) ||
474+ !idxLayout->sublayoutIsZero ({kBlock , kWarp }, kGatherDim ))
475+ return false ;
476+
477+ SmallVector<StringAttr> otherDims;
478+ for (unsigned dim = 0 , rank = srcType.getRank (); dim < rank; ++dim) {
479+ if (dim != gatherOp.getAxis ()) {
480+ otherDims.push_back (b.getStringAttr (" dim" + Twine (dim)));
481+ }
482+ }
483+
484+ // If the gather axis `dimN` is invariant to the warp, but the `(block, warp)`
485+ // mapping to all other dimensions must be the same for both layouts. If so,
486+ // then the warp that owns a particular index element also owns all the source
487+ // elements it could index into.
488+ if (srcLayout->sublayout ({kBlock , kWarp }, otherDims) !=
489+ idxLayout->sublayout ({kBlock , kWarp }, otherDims))
490+ return false ;
491+
492+ // The two constraints above ensure that data-movement to perform the gather
493+ // operation are contained within a warp. The subsequent constraints simplify
494+ // codegen.
495+
496+ // Require that for any given gather column, the threads mapped to the column
497+ // in the index and source tensors are the same. This means we don't need to
498+ // xor shuffle across threads before emitting index shuffles; we push warp
499+ // shuffling to layout conversions.
500+ if (srcLayout->sublayout (kLane , otherDims) !=
501+ idxLayout->sublayout (kLane , otherDims))
502+ return false ;
503+
504+ // Otherwise, the source layout has to be invertible. This primarily means
505+ // the codegen path doesn't support broadcasted source layouts.
506+ return srcLayout->isInvertible ();
507+ }
508+
429509unsigned getNumScratchElements (ArrayRef<unsigned > shape) {
430510 if (shape.empty ())
431511 return 0 ;
0 commit comments