1010#include " iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1111#include " iree/compiler/Codegen/Utils/GPUUtils.h"
1212#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
13+ #include " iree/compiler/Utils/Indexing.h"
1314#include " iree/compiler/Utils/Permutation.h"
1415#include " llvm/ADT/ArrayRef.h"
16+ #include " llvm/ADT/STLExtras.h"
1517#include " llvm/ADT/SmallVector.h"
1618#include " llvm/Support/FormatVariadic.h"
1719#include " mlir/Dialect/Affine/IR/AffineOps.h"
@@ -422,9 +424,70 @@ struct DistributeTransferWrite final
422424 using OpDistributionPattern::OpDistributionPattern;
423425
424426 DistributeTransferWrite (MLIRContext *context, Value threadId,
425- int64_t subgroupSize)
427+ int64_t subgroupSize, ArrayRef< int64_t > workgroupSize )
426428 : OpDistributionPattern(context), threadId(threadId),
427- subgroupSize (subgroupSize) {}
429+ subgroupSize (subgroupSize) {
430+
431+ // The number of threads in the workgroup is the product of the dimensions
432+ // of workgroupSize, unless workgroupSize is empty.
433+ if (!workgroupSize.empty ()) {
434+ numThreadsInWorkgroup = llvm::product_of (workgroupSize);
435+ }
436+ }
437+
438+ // / Compute a boolean in SIMT semantics that is true for the first virtual
439+ // / lane(thread) id (vtid) and virtual subgroup id (vsid) carrying broadcasted
440+ // / data.
441+ // /
442+ // / We do this by computing a basis for vtid and vsid computation, and adding
443+ // / a check for basis elements that are not used (i.e. they are duplicated)
444+ // / to be zero.
445+ FailureOr<Value> getNoOverlapCondition (OpBuilder &b, Location loc,
446+ NestedLayoutAttr layout) const {
447+ ArrayRef<int64_t > threadTile = layout.getThreadTile ();
448+ ArrayRef<int64_t > threadStrides = layout.getThreadStrides ();
449+ ArrayRef<int64_t > subgroupTile = layout.getSubgroupTile ();
450+ // Multiply the subgroup strides by subgroup_size to reflect thread id
451+ // relative strides.
452+ auto subgroupStrides =
453+ llvm::map_to_vector (layout.getSubgroupStrides (),
454+ [&](int64_t x) { return x * subgroupSize; });
455+ auto concatTiles =
456+ llvm::to_vector (llvm::concat<const int64_t >(subgroupTile, threadTile));
457+ auto concatStrides = llvm::to_vector (
458+ llvm::concat<const int64_t >(subgroupStrides, threadStrides));
459+ SmallVector<int64_t > basis;
460+ SmallVector<size_t > dimToResult;
461+ if (failed (basisFromSizesStrides (concatTiles, concatStrides, basis,
462+ dimToResult))) {
463+ return failure ();
464+ }
465+ // Make the outer bound numThreadsInWorkgroup / prod(basis) to remove
466+ // redundant checks.
467+ if (numThreadsInWorkgroup.has_value ()) {
468+ int64_t outerBound =
469+ numThreadsInWorkgroup.value () / llvm::product_of (basis);
470+ basis.insert (basis.begin (), outerBound);
471+ }
472+ // Create a delinearize operation and check that all results not present in
473+ // dimToResult are 0.
474+ SmallVector<Value> delinearized;
475+ b.createOrFold <affine::AffineDelinearizeIndexOp>(
476+ delinearized, loc, threadId, basis,
477+ /* hasOuterbound=*/ numThreadsInWorkgroup.has_value ());
478+ // Get all results which are not in dimToResult and check they are 0.
479+ Value condition = arith::ConstantOp::create (b, loc, b.getBoolAttr (true ));
480+ for (auto [idx, result] : llvm::enumerate (delinearized)) {
481+ if (llvm::is_contained (dimToResult, idx)) {
482+ continue ;
483+ }
484+ Value isZero = b.createOrFold <arith::CmpIOp>(
485+ loc, arith::CmpIPredicate::eq, result,
486+ arith::ConstantIndexOp::create (b, loc, 0 ));
487+ condition = b.createOrFold <arith::AndIOp>(loc, condition, isZero);
488+ }
489+ return condition;
490+ }
428491
429492 LogicalResult matchAndRewrite (vector::TransferWriteOp writeOp,
430493 DistributionSignature &signature,
@@ -456,7 +519,6 @@ struct DistributeTransferWrite final
456519 SmallVector<int64_t > distShape = vectorLayout.getDistributedShape ();
457520 SmallVector<int64_t > tileShape = getElementVectorTileShape (vectorLayout);
458521 int64_t rank = vectorLayout.getRank ();
459-
460522 SmallVector<Value> warpIndices, threadIndices;
461523 if (failed (populateWarpAndThreadIndices (rewriter, threadId, subgroupSize,
462524 vectorLayout, warpIndices,
@@ -465,6 +527,18 @@ struct DistributeTransferWrite final
465527 writeOp, " warp or thread tiles have overlapping strides" );
466528 }
467529
530+ // If the distribution results in threads writing to the same address, guard
531+ // with an scf.if to ensure only one thread writes per duplication group.
532+ Location loc = writeOp.getLoc ();
533+ FailureOr<Value> doWrite =
534+ getNoOverlapCondition (rewriter, loc, vectorLayout);
535+ if (failed (doWrite)) {
536+ return rewriter.notifyMatchFailure (
537+ writeOp, " failed to compute no-overlap condition" );
538+ }
539+ auto ifOp = scf::IfOp::create (rewriter, loc, doWrite.value ());
540+ rewriter.setInsertionPoint (ifOp.thenYield ());
541+
468542 Value distributedVector =
469543 getDistributed (rewriter, writeOp.getValueToStore (), vectorLayout);
470544
@@ -485,7 +559,6 @@ struct DistributeTransferWrite final
485559 SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout (
486560 rewriter, indices, offsets, vectorLayout, permMap, warpIndices,
487561 threadIndices);
488-
489562 // Extract the "element vector" from the inner most dimensions. All outer
490563 // dimensions are either unrolled or distributed such that this is a
491564 // contiguous slice.
@@ -516,6 +589,7 @@ struct DistributeTransferWrite final
516589
517590 Value threadId;
518591 int64_t subgroupSize;
592+ std::optional<int64_t > numThreadsInWorkgroup = std::nullopt ;
519593};
520594
521595// / Pattern to distribute `vector.transfer_gather` ops with nested layouts.
@@ -2127,13 +2201,14 @@ struct DistributeConstantMask final
21272201
21282202} // namespace
21292203
2130- void populateGPUDistributeNestedLayoutAttrPatterns (RewritePatternSet &patterns,
2131- Value threadId,
2132- int64_t subgroupSize,
2133- int64_t maxBitsPerShuffle) {
2134- patterns.add <DistributeTransferRead, DistributeTransferWrite,
2135- DistributeTransferGather, DistributeMapScatter>(
2136- patterns.getContext (), threadId, subgroupSize);
2204+ void populateGPUDistributeNestedLayoutAttrPatterns (
2205+ RewritePatternSet &patterns, Value threadId, int64_t subgroupSize,
2206+ ArrayRef<int64_t > workgroupSize, int64_t maxBitsPerShuffle) {
2207+ patterns.add <DistributeTransferRead, DistributeTransferGather,
2208+ DistributeMapScatter>(patterns.getContext (), threadId,
2209+ subgroupSize);
2210+ patterns.add <DistributeTransferWrite>(patterns.getContext (), threadId,
2211+ subgroupSize, workgroupSize);
21372212 patterns.add <DistributeBroadcast, DistributeTranspose>(patterns.getContext ());
21382213 patterns.add <DistributeMultiReduction>(patterns.getContext (), subgroupSize,
21392214 maxBitsPerShuffle);
0 commit comments