Commit 048922e
[mlir][memref-to-spirv]: Remap Image Load Coordinates (llvm#160495)
When converting a `memref.load` from the image address space to a `spirv.ImageFetch` ensure that we correctly map the load indices to width, height and depth. The lowering currently assumes a linear image tiling, that is row-major memory layout. This allows us to support any memref layout that is a permutation of the dimensions, more complex layouts are not currently supported. Because the ordering of the dimensions in the vector passed to image fetch is the opposite to that in the memref directions a final reversal of the mapped dimensions is always required. --------- Signed-off-by: Jack Frankland <[email protected]>
1 parent 55054db commit 048922e
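The remapping is easiest to see with concrete numbers. Below is a minimal standalone C++ sketch of the permute-then-reverse step the commit message describes, using plain integers in place of SSA values. `remapLoadCoords` and its `perm` argument are hypothetical stand-ins for the pattern's use of `AffineMap::getDimPosition`; they are not part of this commit.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the lowering's remapping: perm[dim] plays the
// role of map.getDimPosition(dim) for a permutation layout map.
std::vector<int> remapLoadCoords(const std::vector<int> &indices,
                                 const std::vector<int> &perm) {
  assert(indices.size() == perm.size());
  std::vector<int> coords(indices.size());
  // Scatter each load index to its position in the layout's dimension order.
  for (size_t dim = 0; dim < perm.size(); ++dim)
    coords[perm[dim]] = indices[dim];
  // The memref layout orders dimensions slowest- to fastest-moving, while the
  // image-fetch coordinate vector is fastest- to slowest-moving, so reverse.
  std::reverse(coords.begin(), coords.end());
  return coords;
}

int main() {
  // Identity (row-major) layout on a rank-3 memref: load indices (i, j, k)
  // become fetch coordinates (k, j, i), i.e. (x, y, z) = (3, 2, 1).
  for (int c : remapLoadCoords({/*i=*/1, /*j=*/2, /*k=*/3}, {0, 1, 2}))
    std::printf("%d ", c); // prints: 3 2 1
}
```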

File tree

2 files changed: +197 −62

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 37 additions & 4 deletions
@@ -699,6 +699,35 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   return success();
 }
 
+template <typename OpAdaptor>
+static FailureOr<SmallVector<Value>>
+extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter) {
+  // At present we only support linear "tiling" as specified in Vulkan; this
+  // means that texels are assumed to be laid out in memory in a row-major
+  // order. This allows us to support any memref layout that is a permutation
+  // of the dimensions. Future work will pass an optional image layout to the
+  // rewrite pattern so that we can support optimized target-specific tilings.
+  SmallVector<Value> indices = adaptor.getIndices();
+  AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
+  if (!map.isPermutation())
+    return rewriter.notifyMatchFailure(
+        loadOp,
+        "Cannot lower memrefs with memory layout which is not a permutation");
+
+  // The memref's layout determines the dimension ordering, so we need to
+  // follow the map to get the ordering of the dimensions/indices.
+  const unsigned dimCount = map.getNumDims();
+  SmallVector<Value, 3> coords(dimCount);
+  for (unsigned dim = 0; dim < dimCount; ++dim)
+    coords[map.getDimPosition(dim)] = indices[dim];
+
+  // We need to reverse the coordinates because the memref layout is slowest-
+  // to fastest-moving and the vector of coordinates for the image op is
+  // fastest- to slowest-moving.
+  return llvm::to_vector(llvm::reverse(coords));
+}
+
 LogicalResult
 ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
@@ -755,13 +784,17 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 
   // Build a vector of coordinates or just a scalar index if we have a 1D image.
   Value coords;
-  if (memrefType.getRank() != 1) {
+  if (memrefType.getRank() == 1) {
+    coords = adaptor.getIndices()[0];
+  } else {
+    FailureOr<SmallVector<Value>> maybeCoords =
+        extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
+    if (failed(maybeCoords))
+      return failure();
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                            adaptor.getIndices().getType()[0]);
     coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
-                                                 adaptor.getIndices());
-  } else {
-    coords = adaptor.getIndices()[0];
+                                                 maybeCoords.value());
   }
 
   // Fetch the value out of the image.
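One more worked case using the hypothetical `remapLoadCoords` sketch from earlier (again an assumption for illustration, not the committed code): a transposed 2-D layout shows how the permutation step and the final reversal interact.

```cpp
// Transposed layout, i.e. layout map (d0, d1) -> (d1, d0), so d0 is the
// fastest-moving dimension in memory. The permutation step scatters the load
// indices (i, j) to (j, i), and the final reversal restores (i, j), putting
// the fastest-moving index first as the image op's x coordinate.
std::vector<int> coords = remapLoadCoords({/*i=*/1, /*j=*/2}, {1, 0});
// coords == {1, 2}
```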
