Commit 472ef9c

[LinalgToXeGPU] Support squeezable any-D memrefs (#414)

Signed-off-by: dchigarev <[email protected]>
Parent: 2d51c4e
4 files changed: +258 additions, -55 deletions

include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 22 additions & 0 deletions

@@ -53,6 +53,28 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
 // Return true if the memref has shared memory space.
 bool hasSharedMemSpace(mlir::Value memref);
 
+// Go through all parent 'memref.subview' ops for the given `memref`
+// and return the folded offsets of all subviews and the root memref.
+void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
+                           Value memref, SmallVector<Value> &resultOffsets,
+                           Value &resultRootMemref);
+
+// Return the strides of the memref
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref);
+
+// Squeeze the leading dimensions of a given memref up to 'maxDims'.
+FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
+                                  Value memref, size_t maxDims = 2);
+
+// Squeeze the leading dimensions of memref operands of a given 'linalgOp'.
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims = 2);
+
+// Return if a memref with the given shape can be squeezed to the shape of
+// 'maxDims'. Only leading dimensions are considered squeezable.
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims = 2);
+
 } // namespace utils
 } // namespace mlir
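
The new helpers encode one "squeezability" rule: a memref may go through the 2D XeGPU lowering as long as every dimension in front of the trailing 'maxDims' dimensions has extent 1. A minimal standalone C++ sketch of that rule (plain C++ for illustration only; the function name and the sample shapes below are not part of the project's API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the shape-only check declared above as utils::canSqueezeDims:
// only leading dimensions may be dropped, and each of them must be 1.
static bool canSqueezeTo(const std::vector<int64_t> &shape,
                         std::size_t maxDims = 2) {
  if (shape.size() <= maxDims)
    return true;
  for (std::size_t i = 0; i < shape.size() - maxDims; ++i)
    if (shape[i] != 1)
      return false;
  return true;
}

int main() {
  // 1x1x8x16 -> squeezable to 8x16, since the leading dims are all 1.
  std::cout << canSqueezeTo({1, 1, 8, 16}) << "\n"; // prints 1
  // 2x4x8x16 -> not squeezable, the leading dims carry real extent.
  std::cout << canSqueezeTo({2, 4, 8, 16}) << "\n"; // prints 0
}

The actual squeeze is performed by reduceMemrefDims / maybeSqueezeDims (see ValueUtils.cpp below); canSqueezeDims is only the cheap shape precondition used while validating operands.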

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 39 additions & 53 deletions
@@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
   return res.getResult();
 }
 
-// Extracts the offsets from a subview operation as values.
-// The differense from mlir::getMixedOffsets is that this function
-// returns the offsets as mlir::Value that can already be used as an argument
-// for other mlir::Operations.
-static SmallVector<Value> extractOffsetsAsValues(PatternRewriter &rewriter,
-                                                 Location loc,
-                                                 memref::SubViewOp subview) {
-  SmallVector<Value> offsetValues;
-  auto staticOffsets = subview.getStaticOffsets();
-  auto dynamicOffsets = subview.getOffsets();
-  size_t dynIdx = 0;
-  for (size_t i = 0; i < staticOffsets.size(); i++) {
-    if (staticOffsets[i] == ShapedType::kDynamic)
-      offsetValues.push_back(dynamicOffsets[dynIdx++]);
-    else
-      offsetValues.push_back(
-          rewriter.create<arith::ConstantIndexOp>(loc, staticOffsets[i]));
-  }
-
-  return offsetValues;
-}
-
 // Max number of elements to load/store from SLM
 constexpr int64_t maxSLMTileSize = 32;
8967

@@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp,
         linalgOp, "Expect memref operand for XeGPU lowering");
   }
 
-  if (type.getShape().size() > maxDims) {
+  if (type.getShape().size() > maxDims &&
+      !utils::canSqueezeDims(type.getShape(), maxDims)) {
     return rewriter.notifyMatchFailure(
         linalgOp, "Too high dimensionality for XeGPU operations");
   }
@@ -856,43 +835,33 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
   auto srcType = cast<MemRefType>(src.getType());
   assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  SmallVector<int64_t> memrefStrides;
-  Value blockOffset;
-
+  SmallVector<Value> offsets;
+  Value rootMemref;
   // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the
   // GPU kernel. We have to merge the subview offsets into the descriptor
   // offset.
-  if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp())) {
-    auto offsets = extractOffsetsAsValues(rewriter, loc, subView);
-    assert(offsets.size() == 2 && "Expected 2D subview offsets");
-
-    auto xIntOffs = offsets[0];
-    auto yIntOffs = offsets[1];
-
-    // compute 'blockOffset' (beginning of the subview block in the original
-    // flat memref)
-    auto rowStride =
-        cast<MemRefType>(subView.getOperand(0).getType()).getShape()[1];
-    auto rowStrideValue =
-        rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
-
-    auto rowBlockOffset =
-        rewriter.create<arith::MulIOp>(loc, xIntOffs, rowStrideValue)
-            .getResult();
-    blockOffset = rewriter.create<arith::AddIOp>(loc, rowBlockOffset, yIntOffs)
-                      .getResult();
+  utils::computeSubviewOffsets(rewriter, loc, src, offsets, rootMemref);
+  auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref);
+  auto rootStrides =
+      getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold);
 
-    memrefStrides = {rowStride, 1};
-    src = subView.getOperand(0);
-  } else {
-    // If the source is not a subview, then the blockOffset is 0
-    blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    memrefStrides = {srcType.getShape()[1], 1};
+  assert(rootStrides.size() == offsets.size() &&
+         "Expected same number of strides and offsets");
+
+  // blockOffset = sum(rootStrides[i] * offsets[i])
+  Value blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (size_t i = 0; i < rootStrides.size(); i++) {
+    auto mul = rewriter.create<arith::MulIOp>(loc, rootStrides[i], offsets[i]);
+    blockOffset = rewriter.create<arith::AddIOp>(loc, blockOffset, mul);
   }
 
-  // Scatter descriptors only work with 1D memrefs
-  src = utils::flattenMemref(rewriter, loc, src);
+  auto memrefStridesFold = utils::getMemrefStrides(rewriter, loc, src);
+  auto [memrefStrides, memrefStridesDynamic] =
+      decomposeMixedValues(memrefStridesFold);
+  assert(memrefStridesDynamic.size() == 0 &&
+         "Expected all values to be resolved");
 
+  src = utils::flattenMemref(rewriter, loc, rootMemref);
   return createScatterDescriptorTiles(
       rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape,
       /*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides,
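
As a worked example of the folded offset above (the concrete numbers are hypothetical): for a 2D subview taken at offsets [3, 5] of a root memref whose strides are [128, 1], the loop yields blockOffset = 3 * 128 + 5 * 1 = 389, the index of the subview's first element in the flattened root buffer. With nested subviews, computeSubviewOffsets first folds all per-dimension offsets against the chain of parent subviews, so the scatter descriptor can be built directly on the root memref and no memref.subview has to survive into the GPU kernel.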
@@ -1839,6 +1808,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(mlir::utils::maybeSqueezeDims(rewriter, gemmLikeOp))) {
+      return rewriter.notifyMatchFailure(
+          gemmLikeOp, "Failed to squeeze dimensions of GEMM-like operation");
+    }
+
     // Ensure that reduction dimension tiling also works for smaller
     // workloads.
     auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs()[0].getType());

@@ -1894,6 +1868,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, eltwiseOp))) {
+      return rewriter.notifyMatchFailure(
+          eltwiseOp,
+          "Could not squeeze dimensions of the elementwise operation");
+    }
+
     return createEltwiseKernel(eltwiseOp, rewriter);
   }

@@ -1988,6 +1968,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
      return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, linalgOp))) {
+      return rewriter.notifyMatchFailure(
+          linalgOp,
+          "Could not squeeze dimensions of the memory fill operation");
+    }
+
     return createMemoryFillKernel(linalgOp, rewriter);
   }

lib/gc/Transforms/Utils/ValueUtils.cpp

Lines changed: 136 additions & 2 deletions
@@ -6,10 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"

@@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) {
   auto srcType = cast<MemRefType>(srcMemref.getType());
 
   assert(srcType && "Expected a memref type");
-  assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
+  auto shapeNd = srcType.getShape();
+  int64_t flatSize =
+      std::accumulate(shapeNd.begin(), shapeNd.end(), 1, std::multiplies<>());
 
   Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);

@@ -193,5 +198,134 @@ bool hasSharedMemSpace(mlir::Value memref) {
   return false;
 }
 
+void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
+                           Value memref, SmallVector<Value> &resultOffsets,
+                           Value &resultRootMemref) {
+  auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto type = dyn_cast<MemRefType>(memref.getType());
+  assert(type && "Expected a memref type");
+
+  auto origShape = type.getShape();
+
+  resultOffsets.clear();
+  resultOffsets.append(origShape.size(), fillVal);
+  resultRootMemref = memref;
+
+  while (auto subViewOp = resultRootMemref.getDefiningOp<memref::SubViewOp>()) {
+    auto currentOffsets = getAsOpFoldResult(resultOffsets);
+    resultOffsets.clear();
+
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, resultRootMemref.getLoc(), subViewOp.getMixedOffsets(),
+        subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets,
+        resultOffsets);
+    resultRootMemref = subViewOp.getOperand(0);
+  }
+}
+
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+
+  auto stridedLayout = dyn_cast<StridedLayoutAttr>(type.getLayout());
+  if (stridedLayout) {
+    auto strides = stridedLayout.getStrides();
+    return getMixedValues(strides, {}, rewriter);
+  }
+
+  auto sizes = getMixedValues(type.getShape(), {}, rewriter);
+  auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes);
+  return strides;
+}
+
+FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
+                                  Value memref, size_t maxDims = 2) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+  auto shape = type.getShape();
+
+  if (shape.size() <= maxDims)
+    return memref;
+
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return failure();
+
+  auto offsets =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 0), {}, rewriter);
+  auto sizes = getMixedValues(shape, {}, rewriter);
+  auto staticStrides = utils::getStaticStrides(memref).value();
+  auto strides =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 1), {}, rewriter);
+
+  SmallVector<int64_t> newShape(shape.begin() + shape.size() - maxDims,
+                                shape.end());
+  SmallVector<int64_t> newStrides(
+      staticStrides.begin() + shape.size() - maxDims, staticStrides.end());
+
+  int64_t newOffset = 0;
+  if (auto memrefLayout = dyn_cast<StridedLayoutAttr>(type.getLayout()))
+    newOffset = memrefLayout.getOffset();
+
+  auto newLayout = StridedLayoutAttr::get(
+      rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides);
+  MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(),
+                                             newLayout, type.getMemorySpace());
+
+  auto squeezedSubview =
+      rewriter
+          .create<memref::SubViewOp>(loc, newMemRefType, memref, offsets, sizes,
+                                     strides)
+          .getResult();
+  return squeezedSubview;
+}
+
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims) {
+  SmallVector<std::pair<size_t, Value>> newOperands;
+  auto operands = linalgOp->getOperands();
+  auto loc = linalgOp.getLoc();
+
+  for (size_t i = 0; i < operands.size(); i++) {
+    auto operand = operands[i];
+    auto type = dyn_cast<MemRefType>(operand.getType());
+    if (!type) {
+      // Skip non-memref operands
+      continue;
+    }
+
+    if (type.getShape().size() <= maxDims)
+      continue;
+
+    auto res = reduceMemrefDims(rewriter, loc, operand, maxDims);
+    if (failed(res)) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Can't squeeze memref to the desired number of dimensions");
+    }
+
+    auto flatSubview = res.value();
+    newOperands.emplace_back(i, flatSubview);
+  }
+
+  if (newOperands.empty())
+    return success();
+
+  rewriter.modifyOpInPlace(linalgOp, [&] {
+    for (auto [i, operand] : newOperands)
+      linalgOp->setOperand(i, operand);
+  });
+  return success();
+}
+
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims) {
+  if (shape.size() <= maxDims)
+    return true;
+
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return false;
+
+  return true;
+}
+
 } // namespace utils
 } // namespace mlir
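
reduceMemrefDims materializes the squeeze as a rank-reducing memref.subview with zero offsets, the original sizes, and unit relative strides; the result type keeps only the trailing 'maxDims' sizes and static strides of the source, and carries over the source layout's offset (a dynamic offset stays dynamic). For the shapes exercised by the test below, memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> is squeezed to memref<8x16xf16, strided<[16, 1], offset: ?>>, so only the view's type changes and no data is moved. maybeSqueezeDims then swaps the squeezed views into the linalg op's operands in place.
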
Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s
+
+!input_type = memref<2x4x8x16xf16>
+!chunk_type = memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>>
+!slm_chunk = memref<1x1x8x16xf16, strided<[128, 128, 16, 1], offset: ?>, 3>
+
+// The map that computes an offset for SLM
+// CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+#map = affine_map<(xi, yi) -> (xi * 4 + yi)>
+
+func.func @entry(%arg0: !input_type, %arg1: !input_type, %arg2: !input_type) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c2, %arg13 = %c4, %arg14 = %c1) {
+    // CHECK: %[[ARG0_SB:.+]] = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg0_sb = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+    // CHECK: %[[ARG1_SB:.+]] = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg1_sb = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+    // CHECK: %[[ARG2_SB:.+]] = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg2_sb = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+
+    // CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<8x1x8x16xf16, 3>
+    %slm_root = memref.alloc() : memref<8x1x8x16xf16, 3>
+
+    %slm_idx = affine.apply #map(%arg6, %arg7)
+    %slm = memref.subview %slm_root[%slm_idx, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<8x1x8x16xf16, 3> to !slm_chunk
+
+    // Squeezing the arguments of 'linalg.mul'
+    // CHECK: %[[ARG0_SQUEEZ:.+]] = memref.subview %[[ARG0_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // CHECK: %[[ARG1_SQUEEZ:.+]] = memref.subview %[[ARG1_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // Verify that tensor descriptors are created from the squeezed memrefs
+    // CHECK: xegpu.create_nd_tdesc %[[ARG0_SQUEEZ]]
+    // CHECK: xegpu.create_nd_tdesc %[[ARG1_SQUEEZ]]
+
+    // Verify that the SLM output of linalg.mul is squeezed correctly
+    // CHECK-NOT: .* = memref.subview %[[SLM_BUFF]] .*
+    // CHECK: %[[SLM_THREAD_OFF:.+]] = affine.apply #map(%arg6, %arg7)
+    // CHECK: %[[SLM_OFF:.+]] = arith.muli %[[SLM_THREAD_OFF]], %c128 : index
+    // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c1024], strides: [%c1] : memref<8x1x8x16xf16, 3> to memref<1024xf16, 3>
+    // CHECK: xegpu.create_tdesc %[[FLAT_SLM]]
+    linalg.mul ins(%arg0_sb, %arg1_sb : !chunk_type, !chunk_type) outs(%slm : !slm_chunk)
+
+    // Squeezing the result buffer of 'linalg.add'
+    // CHECK: %[[ARG2_SQUEEZ:.+]] = memref.subview %[[ARG2_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // Verify that tensor descriptors are created from the squeezed memrefs
+    // CHECK: xegpu.create_nd_tdesc %[[ARG2_SQUEEZ]]
+    linalg.add ins(%arg0_sb, %slm : !chunk_type, !slm_chunk) outs(%arg2_sb : !chunk_type)
+
+    gpu.terminator
+  } {SCFToGPU_visited}
+
+  return
+}
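
The SLM-related CHECK lines encode a small offset computation worth spelling out: each thread's chunk covers 1 * 1 * 8 * 16 = 128 f16 elements, and #map assigns thread (%arg6, %arg7) the linear index %arg6 * 4 + %arg7. The scatter descriptor into the flattened shared buffer therefore starts at that linear index times 128 within the 8 * 1 * 8 * 16 = 1024-element buffer, which is exactly the arith.muli by %c128 and the memref.reinterpret_cast to memref<1024xf16, 3> that the test expects.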
