Commit 9dd135e

[LinalgToXeGPU] Support squeezable any-D memrefs

Signed-off-by: dchigarev <[email protected]>
Parent: 238706b

File tree (3 files changed: +188 −55 lines):

include/gc/Transforms/Utils/ValueUtils.h
lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
lib/gc/Transforms/Utils/ValueUtils.cpp

include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 21 additions & 0 deletions
@@ -53,6 +53,27 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
 // Return true if the memref has shared memory space.
 bool hasSharedMemSpace(mlir::Value memref);
 
+// Walk all parent 'memref.subview' ops of the given 'memref' and return
+// the folded offsets of all subviews together with the root memref.
+std::tuple<SmallVector<Value>, Value>
+computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref);
+
+// Return the strides of the memref.
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref);
+
+// Squeeze the leading dimensions of a given memref down to 'maxDims'.
+FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
+                               Value memref, size_t maxDims = 2);
+
+// Squeeze the leading dimensions of the memref operands of a given 'linalgOp'.
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims = 2);
+
+// Return true if a memref with the given shape can be squeezed to 'maxDims'
+// dimensions. Only leading dimensions are considered squeezable.
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims = 2);
+
 } // namespace utils
 } // namespace mlir
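
A rough sketch of how these helpers are meant to compose (the 'lowerOp' wrapper below is hypothetical, for illustration only; the real call sites are in LinalgToXeGPU.cpp further down): first gate the pattern on 'canSqueezeDims', then rewrite the operands with 'maybeSqueezeDims'.

// Illustrative sketch only; assumes the usual MLIR includes and the
// declarations above. 'lowerOp' is a hypothetical wrapper, not part of
// this commit.
LogicalResult lowerOp(PatternRewriter &rewriter, linalg::LinalgOp op,
                      size_t maxDims = 2) {
  for (Value operand : op->getOperands()) {
    auto type = dyn_cast<MemRefType>(operand.getType());
    // Reject operands that are too high-dimensional and not squeezable.
    if (type && type.getShape().size() > maxDims &&
        !mlir::utils::canSqueezeDims(type.getShape(), maxDims))
      return failure();
  }
  // Rewrite squeezable operands in place via rank-reducing memref.subview.
  return mlir::utils::maybeSqueezeDims(rewriter, op, maxDims);
}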

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 37 additions & 53 deletions
@@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
   return res.getResult();
 }
 
-// Extracts the offsets from a subview operation as values.
-// The difference from mlir::getMixedOffsets is that this function
-// returns the offsets as mlir::Value that can already be used as an argument
-// for other mlir::Operations.
-static SmallVector<Value> extractOffsetsAsValues(PatternRewriter &rewriter,
-                                                 Location loc,
-                                                 memref::SubViewOp subview) {
-  SmallVector<Value> offsetValues;
-  auto staticOffsets = subview.getStaticOffsets();
-  auto dynamicOffsets = subview.getOffsets();
-  size_t dynIdx = 0;
-  for (size_t i = 0; i < staticOffsets.size(); i++) {
-    if (staticOffsets[i] == ShapedType::kDynamic)
-      offsetValues.push_back(dynamicOffsets[dynIdx++]);
-    else
-      offsetValues.push_back(
-          rewriter.create<arith::ConstantIndexOp>(loc, staticOffsets[i]));
-  }
-
-  return offsetValues;
-}
-
 // Max number of elements to load/store from SLM
 constexpr int64_t maxSLMTileSize = 32;

@@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp,
         linalgOp, "Expect memref operand for XeGPU lowering");
   }
 
-  if (type.getShape().size() > maxDims) {
+  if (type.getShape().size() > maxDims &&
+      !utils::canSqueezeDims(type.getShape(), maxDims)) {
     return rewriter.notifyMatchFailure(
         linalgOp, "Too high dimensionality for XeGPU operations");
   }
@@ -856,43 +835,31 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
   auto srcType = cast<MemRefType>(src.getType());
   assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  SmallVector<int64_t> memrefStrides;
-  Value blockOffset;
-
   // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the
   // GPU kernel. We have to merge the subview offsets into the descriptor
   // offset.
-  if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp())) {
-    auto offsets = extractOffsetsAsValues(rewriter, loc, subView);
-    assert(offsets.size() == 2 && "Expected 2D subview offsets");
-
-    auto xIntOffs = offsets[0];
-    auto yIntOffs = offsets[1];
-
-    // compute 'blockOffset' (beginning of the subview block in the original
-    // flat memref)
-    auto rowStride =
-        cast<MemRefType>(subView.getOperand(0).getType()).getShape()[1];
-    auto rowStrideValue =
-        rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
-
-    auto rowBlockOffset =
-        rewriter.create<arith::MulIOp>(loc, xIntOffs, rowStrideValue)
-            .getResult();
-    blockOffset = rewriter.create<arith::AddIOp>(loc, rowBlockOffset, yIntOffs)
-                      .getResult();
+  auto [offsets, rootMemref] = utils::computeSubviewOffsets(rewriter, loc, src);
+  auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref);
+  auto rootStrides =
+      getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold);
 
-    memrefStrides = {rowStride, 1};
-    src = subView.getOperand(0);
-  } else {
-    // If the source is not a subview, then the blockOffset is 0
-    blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    memrefStrides = {srcType.getShape()[1], 1};
+  assert(rootStrides.size() == offsets.size() &&
+         "Expected same number of strides and offsets");
+
+  // blockOffset = sum(rootStrides[i] * offsets[i])
+  Value blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (size_t i = 0; i < rootStrides.size(); i++) {
+    auto mul = rewriter.create<arith::MulIOp>(loc, rootStrides[i], offsets[i]);
+    blockOffset = rewriter.create<arith::AddIOp>(loc, blockOffset, mul);
   }
 
-  // Scatter descriptors only work with 1D memrefs
-  src = utils::flattenMemref(rewriter, loc, src);
+  auto memrefStridesFold = utils::getMemrefStrides(rewriter, loc, src);
+  auto [memrefStrides, memrefStridesDynamic] =
+      decomposeMixedValues(memrefStridesFold);
+  assert(memrefStridesDynamic.size() == 0 &&
+         "Expected all values to be resolved");
 
+  src = utils::flattenMemref(rewriter, loc, rootMemref);
   return createScatterDescriptorTiles(
       rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape,
       /*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides,
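
As a scalar model of the offset folding above: for a root memref<8x16xf32> (row-major strides [16, 1]) and folded subview offsets [2, 3], the descriptor offset is 2*16 + 3*1 = 35. A minimal standalone check, with plain integers standing in for the 'index' Values (illustrative only):

#include <array>
#include <cassert>
#include <cstddef>

int main() {
  // Root memref<8x16xf32>: row-major strides [16, 1].
  std::array<long, 2> rootStrides = {16, 1};
  // Subview offsets folded into the root memref.
  std::array<long, 2> offsets = {2, 3};

  // blockOffset = sum(rootStrides[i] * offsets[i]), as in the loop above.
  long blockOffset = 0;
  for (std::size_t i = 0; i < rootStrides.size(); ++i)
    blockOffset += rootStrides[i] * offsets[i];

  assert(blockOffset == 35); // element (2, 3) in the flattened buffer
  return 0;
}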
@@ -1839,6 +1806,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(mlir::utils::maybeSqueezeDims(rewriter, gemmLikeOp))) {
+      return rewriter.notifyMatchFailure(
+          gemmLikeOp, "Failed to squeeze dimensions of GEMM-like operation");
+    }
+
     // Ensure that reduction dimension tiling also works for smaller
     // workloads.
     auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs()[0].getType());
@@ -1894,6 +1866,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, eltwiseOp))) {
+      return rewriter.notifyMatchFailure(
+          eltwiseOp,
+          "Could not squeeze dimensions of the elementwise operation");
+    }
+
     return createEltwiseKernel(eltwiseOp, rewriter);
   }

@@ -1988,6 +1966,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, linalgOp))) {
+      return rewriter.notifyMatchFailure(
+          linalgOp,
+          "Could not squeeze dimensions of the memory fill operation");
+    }
+
     return createMemoryFillKernel(linalgOp, rewriter);
   }

lib/gc/Transforms/Utils/ValueUtils.cpp

Lines changed: 130 additions & 2 deletions
@@ -6,10 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"
@@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) {
   auto srcType = cast<MemRefType>(srcMemref.getType());
 
   assert(srcType && "Expected a memref type");
-  assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
+  auto shapeNd = srcType.getShape();
+  int64_t flatSize = std::accumulate(shapeNd.begin(), shapeNd.end(),
+                                     int64_t{1}, std::multiplies<>());
 
   Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);
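
With the rank-2 restriction lifted, the flat size is the product of all dimensions, e.g. memref<2x3x4xf32> flattens to 2*3*4 = 24 elements. The same std::accumulate idiom on plain integers (illustrative only; note the int64_t init to keep the accumulation wide):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3, 4}; // any rank, not just 2D
  int64_t flatSize = std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                     std::multiplies<>());
  assert(flatSize == 24);
  return 0;
}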
@@ -193,5 +198,128 @@ bool hasSharedMemSpace(mlir::Value memref) {
   return false;
 }
 
+std::tuple<SmallVector<Value>, Value>
+computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref) {
+  auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto origShape = dyn_cast<MemRefType>(memref.getType()).getShape();
+
+  SmallVector<Value> resolvedOffsets(origShape.size(), fillVal);
+
+  while (auto subViewOp = memref.getDefiningOp<memref::SubViewOp>()) {
+    auto currentOffsets = getAsOpFoldResult(resolvedOffsets);
+    resolvedOffsets.clear();
+
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, memref.getLoc(), subViewOp.getMixedOffsets(),
+        subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets,
+        resolvedOffsets);
+    memref = subViewOp.getOperand(0);
+  }
+
+  return std::make_tuple(resolvedOffsets, memref);
+}
+
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+
+  auto stridedLayout = dyn_cast<StridedLayoutAttr>(type.getLayout());
+  if (stridedLayout) {
+    auto strides = stridedLayout.getStrides();
+    return getMixedValues(strides, {}, rewriter);
+  }
+
+  auto sizes = getMixedValues(type.getShape(), {}, rewriter);
+  auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes);
+  return strides;
+}
+
+FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
+                               Value memref, size_t maxDims) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+  auto shape = type.getShape();
+
+  if (shape.size() <= maxDims)
+    return memref;
+
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return failure();
+
+  auto offsets =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 0), {}, rewriter);
+  auto sizes = getMixedValues(shape, {}, rewriter);
+  auto staticStrides = utils::getStaticStrides(memref).value();
+  auto strides =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 1), {}, rewriter);
+
+  SmallVector<int64_t> newShape(shape.begin() + shape.size() - maxDims,
+                                shape.end());
+  SmallVector<int64_t> newStrides(
+      staticStrides.begin() + shape.size() - maxDims, staticStrides.end());
+
+  int64_t newOffset = 0;
+  if (auto memrefLayout = dyn_cast<StridedLayoutAttr>(type.getLayout()))
+    newOffset = memrefLayout.getOffset();
+
+  auto newLayout = StridedLayoutAttr::get(
+      rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides);
+  MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(),
+                                             newLayout, type.getMemorySpace());
+
+  auto squeezedSubview =
+      rewriter
+          .create<memref::SubViewOp>(loc, newMemRefType, memref, offsets,
+                                     sizes, strides)
+          .getResult();
+  return squeezedSubview;
+}
+
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims) {
+  SmallVector<std::pair<size_t, Value>> newOperands;
+  auto operands = linalgOp->getOperands();
+  auto loc = linalgOp.getLoc();
+
+  for (size_t i = 0; i < operands.size(); i++) {
+    auto operand = operands[i];
+    auto type = dyn_cast<MemRefType>(operand.getType());
+    if (!type) {
+      // TODO: skip non-memref operands with 'continue' if such a case ever
+      // appears.
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Expect memref operand for XeGPU lowering");
+    }
+
+    if (type.getShape().size() <= maxDims)
+      continue;
+
+    auto res = squeezeMemref(rewriter, loc, operand, maxDims);
+    if (failed(res)) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Can't squeeze memref to the desired number of dimensions");
+    }
+
+    auto flatSubview = res.value();
+    newOperands.emplace_back(i, flatSubview);
+  }
+
+  for (auto [i, operand] : newOperands)
+    linalgOp->setOperand(i, operand);
+
+  return success();
+}
+
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims) {
+  if (shape.size() <= maxDims)
+    return true;
+
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return false;
+
+  return true;
+}
+
 } // namespace utils
 } // namespace mlir
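
To make the squeeze concrete: a memref<1x1x32x32xf32> with static strides [1024, 1024, 32, 1] passes canSqueezeDims(shape, /*maxDims=*/2), and squeezeMemref rewrites it to a rank-2 subview typed memref<32x32xf32, strided<[32, 1]>>. A scalar model of the shape/stride slicing (illustrative only; plain vectors stand in for memref metadata):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors canSqueezeDims: only leading unit dims are squeezable.
static bool canSqueeze(const std::vector<int64_t> &shape,
                       std::size_t maxDims) {
  if (shape.size() <= maxDims)
    return true;
  for (std::size_t i = 0; i < shape.size() - maxDims; ++i)
    if (shape[i] != 1)
      return false;
  return true;
}

int main() {
  const std::vector<int64_t> shape = {1, 1, 32, 32};
  const std::vector<int64_t> strides = {1024, 1024, 32, 1};
  const std::size_t maxDims = 2;

  assert(canSqueeze(shape, maxDims));
  assert(!canSqueeze({2, 1, 32, 32}, maxDims)); // non-unit leading dim

  // squeezeMemref keeps only the trailing maxDims sizes and strides.
  std::vector<int64_t> newShape(shape.end() - maxDims, shape.end());
  std::vector<int64_t> newStrides(strides.end() - maxDims, strides.end());
  assert(newShape == std::vector<int64_t>({32, 32}));
  assert(newStrides == std::vector<int64_t>({32, 1}));
  return 0;
}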
