Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,58 @@ struct LowerCoalescedGatherDMAPattern final
}
SmallVector<TransferSegment> segments = std::move(*segmentsOrFailure);

// OOB padding requires fat_raw_buffer for hardware OOB clamping.
if (std::optional<ArrayAttr> inBounds = dmaOp.getInBounds()) {
auto srcType = cast<MemRefType>(source.getType());
if (!hasAMDGPUFatRawBufferAddressSpace(srcType)) {
for (Attribute attr : *inBounds) {
if (!cast<BoolAttr>(attr).getValue()) {
return rewriter.notifyMatchFailure(
dmaOp, "in_bounds with OOB dimensions requires "
"fat_raw_buffer address space on source");
}
}
}

// For non-outermost dims with OOB (in_bounds=false), the vector read
// must not cross row boundaries. Each lane reads `elementsPerLane`
// contiguous elements from the source buffer. If the source dim size is
// not a multiple of elementsPerLane, a vector read near the end of a row
// will wrap into the next row instead of returning zeros.
// Example: source 64x62xf32, dest 64x64xf32, vector<4xf32>:
// Lane at [0, 60] reads 4 elements at flat offsets 60..63.
// Offset 62 wraps to [1, 0] instead of returning 0.
ArrayRef<int64_t> sourceShape = srcType.getShape();
for (int64_t dim = 1; dim < srcType.getRank(); ++dim) {
if (dim >= static_cast<int64_t>(inBounds->size())) {
break;
}
bool dimInBounds = cast<BoolAttr>((*inBounds)[dim]).getValue();
if (dimInBounds) {
continue;
}
// This non-outermost dim has padding. Check that source dim size is
// a multiple of elementsPerLane for every segment to prevent row
// crossing.
if (ShapedType::isDynamic(sourceShape[dim])) {
return rewriter.notifyMatchFailure(
dmaOp, "non-outermost OOB dim " + Twine(dim) +
" has dynamic source size; cannot verify vector "
"reads do not cross row boundaries");
}
for (const TransferSegment &segment : segments) {
if (sourceShape[dim] % segment.elementsPerLane != 0) {
return rewriter.notifyMatchFailure(
dmaOp, "non-outermost OOB dim " + Twine(dim) +
" has source size " + Twine(sourceShape[dim]) +
" not divisible by elementsPerLane " +
Twine(segment.elementsPerLane) +
"; vector reads would cross row boundaries");
}
}
}
}

// Set up for code generation.
rewriter.setInsertionPoint(dmaOp);
TypedValue<IndexType> laneId = dmaOp.getLane();
Expand All @@ -304,7 +356,8 @@ struct LowerCoalescedGatherDMAPattern final
}

emitTransfers(rewriter, loc, source, dest, destShape, numLinearDims,
elementType, indices, segments, segmentLaneOffsets);
elementType, indices, segments, segmentLaneOffsets,
dmaOp.getInBounds());

rewriter.eraseOp(dmaOp);
return success();
Expand Down Expand Up @@ -337,7 +390,8 @@ struct LowerCoalescedGatherDMAPattern final
Value dest, ArrayRef<int64_t> destShape,
int64_t numLinearDims, Type elementType,
OperandRange indices, ArrayRef<TransferSegment> segments,
ArrayRef<Value> segmentLaneOffsets) const {
ArrayRef<Value> segmentLaneOffsets,
std::optional<ArrayAttr> inBoundsAttr) const {
int64_t destRank = destShape.size();
int64_t numOuterDims = destRank - numLinearDims;
LDBG() << "Emitting transfers: " << numOuterDims << " outer dims, "
Expand Down Expand Up @@ -400,6 +454,55 @@ struct LowerCoalescedGatherDMAPattern final
auto [srcIndices, dstIndices] = generateGatherIndices(
rewriter, loc, srcDimOffsets, dstDimOffsets, indices);

// Raw buffer OOB clamping is 1D (linear): it returns 0 only when the
// byte offset >= total buffer size. For non-outermost dimensions,
// an OOB index wraps into the next row instead of returning 0.
// Fix: when any non-outermost source index exceeds its dimension,
// replace the outermost index with sourceShape[0] to force the
// linearized offset past the buffer end → hardware returns 0.
auto sourceType = cast<MemRefType>(source.getType());
if (inBoundsAttr && hasAMDGPUFatRawBufferAddressSpace(sourceType)) {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
Value anyNonOutermostOOB = arith::ConstantOp::create(
rewriter, loc, rewriter.getBoolAttr(false));

for (int64_t dim = 1; dim < sourceType.getRank(); ++dim) {
if (dim >= static_cast<int64_t>(inBoundsAttr->size())) {
break;
}
bool dimInBounds =
cast<BoolAttr>((*inBoundsAttr)[dim]).getValue();
if (dimInBounds) {
continue;
}

Value dimSize;
if (ShapedType::isDynamic(sourceShape[dim])) {
dimSize = memref::DimOp::create(rewriter, loc, source, dim);
} else {
dimSize = arith::ConstantIndexOp::create(rewriter, loc,
sourceShape[dim]);
}

Value isOOB = arith::CmpIOp::create(rewriter, loc,
arith::CmpIPredicate::uge,
srcIndices[dim], dimSize);

anyNonOutermostOOB = arith::OrIOp::create(
rewriter, loc, anyNonOutermostOOB, isOOB);
}

Value oobOuterIdx;
if (ShapedType::isDynamic(sourceShape[0])) {
oobOuterIdx = memref::DimOp::create(rewriter, loc, source, 0);
} else {
oobOuterIdx =
arith::ConstantIndexOp::create(rewriter, loc, sourceShape[0]);
}
srcIndices[0] = arith::SelectOp::create(
rewriter, loc, anyNonOutermostOOB, oobOuterIdx, srcIndices[0]);
}

amdgpu::GatherToLDSOp::create(rewriter, loc, source, srcIndices, dest,
dstIndices,
TypeAttr::get(transferType));
Expand Down Expand Up @@ -438,18 +541,20 @@ struct AMDGPULowerCoalescedDMAToGatherLDSPass final

walkAndApplyPatterns(funcOp, std::move(patterns));

#ifndef NDEBUG
// Verify all CoalescedGatherDMAOps were lowered. Currently, we require all
// ops to be successfully lowered. In the future, a fallback lowering path
// (e.g., using global_load) could handle ops that don't match the pattern.
WalkResult result = funcOp.walk([&](IREE::GPU::CoalescedGatherDMAOp op) {
op.emitOpError("failed to lower coalesced_gather_dma op");
op.emitOpError(
"failed to lower to gather_to_lds; possible causes: source "
"lacks fat_raw_buffer address space for OOB padding, destination "
"is not contiguous, or element sizes are incompatible with "
"dma_sizes");
return WalkResult::interrupt();
});
if (result.wasInterrupted()) {
return signalPassFailure();
}
#endif // NDEBUG
}
};
} // namespace
Expand Down
Loading
Loading