Commit 98728d9

[MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (#152429)
Lower transfer_read/transfer_write to load_gather/store_scatter when the target uArch doesn't support load_nd/store_nd. The high-level steps are: 1. compute strides; 2. compute offsets; 3. collapse the memref to 1D; 4. create the load_gather or store_scatter op.
1 parent 37cc010 commit 98728d9
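
For orientation, here is a hedged sketch of what the scattered path produces for a simple 2-D transfer_read on a target without 2-D block load support. It is written for this summary in the same commented pseudo-IR style the patch's own comments use; it is not taken from the patch or its tests, the value names (%src, %i, %j, %pad, %v) are hypothetical, and the exact xegpu.load_gather assembly is abbreviated:

  // Input (hypothetical):
  //   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
  //          : memref<8x16xf32>, vector<8x16xf32>
  //
  // After the lowering, roughly:
  //   1. strides of %src: [16, 1]
  //   2. %offsets : vector<8x16xindex>   // (%i*16 + %j) + row*16 + col*1, built from
  //                                      // vector.step/shape_cast/broadcast and arith.muli/addi
  //   3. %flat = memref.collapse_shape %src [[0, 1]] : memref<8x16xf32> into memref<128xf32>
  //   4. %mask = vector.constant_mask [8, 16] : vector<8x16xi1>
  //      %v    = xegpu.load_gather %flat, %offsets, %mask ... -> vector<8x16xf32>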

9 files changed: +854 −233 lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 5 additions & 0 deletions
@@ -123,6 +123,11 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 void doSCFStructuralTypeConversionWithTensorType(Operation *op,
                                                  TypeConverter converter);
 
+/// Retrieves the chip string from the XeVM target attribute of the parent
+/// GPU module operation. Returns the chip identifier if found, or nullopt
+/// if no GPU module parent or XeVM target attribute exists.
+std::optional<std::string> getChipStr(Operation *op);
+
 } // namespace xegpu
 
 } // namespace mlir

mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU
   MLIRTransforms
   MLIRVectorDialect
   MLIRXeGPUDialect
+  MLIRXeGPUUtils
 )

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 306 additions & 7 deletions
@@ -14,9 +14,11 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -68,18 +70,14 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
 
-  // Perform common data transfer checks.
-  VectorType vecTy = xferOp.getVectorType();
-  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
-    return failure();
-
   // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  VectorType vecTy = xferOp.getVectorType();
   unsigned vecRank = vecTy.getRank();
   if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
     return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the innermost strides in the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is computed
+// using the inverse and broadcasting-aware version of the permutation map,
+// and is applied to the relevant strides. This ensures that memory accesses
+// are consistent with the logical permutation of vector elements.
+//
+// Example:
+//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
+//   [1, 0]), then after calling this function, the last two strides will be
+//   swapped:
+//     Original strides:  [s0, s1, s2, s3]
+//     After permutation: [s0, s1, s3, s2]
+//
+static void adjustStridesForPermutation(AffineMap permMap,
+                                        SmallVectorImpl<Value> &strides) {
+
+  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+  SmallVector<unsigned> perms;
+  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+  SmallVector<int64_t> perms64(perms.begin(), perms.end());
+  strides = applyPermutation(strides, perms64);
+}
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                         PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+      return {};
+    // Wrap static strides as MLIR values.
+    for (int64_t s : intStrides)
+      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+  } else {
+    // For a dynamically shaped memref, use memref.extract_strided_metadata to
+    // get the stride values.
+    unsigned rank = memrefType.getRank();
+    Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+    // size0, size1, ..., sizeN-1]
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (unranked)
+    resultTypes.push_back(indexType);      // offset
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // strides
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // sizes
+
+    auto meta = memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, resultTypes, baseMemref);
+    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+  }
+  // Adjust strides according to the permutation map (e.g., for transpose).
+  adjustStridesForPermutation(permMap, strides);
+  return strides;
+}
+
+// Computes the vector of local offsets for scattered loads/stores. It is used
+// in the lowering of vector.transfer_read/write to load_gather/store_scatter.
+// Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//        %cst {in_bounds = [true, true, true, true]} :
+//        memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+//   %6 = vector.step : vector<4xindex>
+//   %7 = vector.step : vector<2xindex>
+//   %8 = vector.step : vector<6xindex>
+//   %9 = vector.step : vector<32xindex>
+//   %10 = arith.mul %6, 384
+//   %11 = arith.mul %7, 192
+//   %12 = arith.mul %8, 32
+//   %13 = arith.mul %9, 1
+//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
+//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
+//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
+//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
+//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+//   %22 = arith.add %18, %19
+//   %23 = arith.add %20, %21
+//   %local_offsets = arith.add %22, %23
+//   %orig_offset = %block_id_y * 4x2x6x32 // consider using an affine map
+//   %offsets = %orig_offset + %local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+                            PatternRewriter &rewriter,
+                            ArrayRef<Value> strides) {
+  Location loc = xferOp.getLoc();
+  VectorType vectorType = xferOp.getVectorType();
+  SmallVector<Value> indices(xferOp.getIndices().begin(),
+                             xferOp.getIndices().end());
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  // Create vector.step operations for each dimension.
+  SmallVector<Value> stepVectors;
+  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
+    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
+    stepVectors.push_back(stepOp);
+    return stepOp;
+  });
+
+  // Multiply step vectors by the corresponding strides.
+  size_t memrefRank = strides.size();
+  size_t vectorRank = vectorShape.size();
+  SmallVector<Value> strideMultiplied;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    size_t memrefDim = memrefRank - vectorRank + i;
+    Value strideValue = strides[memrefDim];
+    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+    auto bcastOp =
+        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+    strideMultiplied.push_back(mulOp);
+  }
+
+  // Shape-cast each multiplied vector to add singleton dimensions.
+  SmallVector<Value> shapeCasted;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    SmallVector<int64_t> newShape(vectorRank, 1);
+    newShape[i] = vectorShape[i];
+    auto newType = VectorType::get(newShape, rewriter.getIndexType());
+    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+                                              strideMultiplied[i]);
+    shapeCasted.push_back(castOp);
+  }
+
+  // Broadcast each shape-casted vector to the full vector shape.
+  SmallVector<Value> broadcasted;
+  auto fullIndexVectorType =
+      VectorType::get(vectorShape, rewriter.getIndexType());
+  for (Value shapeCastVal : shapeCasted) {
+    auto broadcastOp = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, shapeCastVal);
+    broadcasted.push_back(broadcastOp);
+  }
+
+  // Add all broadcasted vectors together to compute the local offsets.
+  Value localOffsets = broadcasted[0];
+  for (size_t i = 1; i < broadcasted.size(); ++i)
+    localOffsets =
+        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+  // Compute the base offset from the transfer read indices.
+  Value baseOffset = nullptr;
+  if (!indices.empty()) {
+    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    for (size_t i = 0; i < indices.size(); ++i) {
+      Value strideVal = strides[i];
+      Value offsetContrib =
+          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+      baseOffset =
+          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+    }
+    // Broadcast the base offset to match the vector shape.
+    Value bcastBase = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, baseOffset);
+    localOffsets =
+        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  }
+  return localOffsets;
+}
+
+// Collapse the memref shape to 1D.
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  Location loc = xferOp.getLoc();
+
+  Value baseMemref = xferOp.getBase();
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+  Type elementType = memrefType.getElementType();
+
+  // Compute the total number of elements in the memref.
+  MemRefType flatMemrefType;
+  if (memrefType.hasStaticShape()) {
+    auto totalElements = memrefType.getNumElements();
+    flatMemrefType = MemRefType::get({totalElements}, elementType);
+  } else {
+    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+  }
+
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices allDims =
+      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+  reassociation.push_back(allDims);
+
+  auto collapseOp = memref::CollapseShapeOp::create(
+      rewriter, loc, flatMemrefType, baseMemref, reassociation);
+  return collapseOp;
+}
+
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+                                            PatternRewriter &rewriter) {
+
+  Location loc = readOp.getLoc();
+  VectorType vectorType = readOp.getVectorType();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(readOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  auto gatherOp = xegpu::LoadGatherOp::create(
+      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/xegpu::CachePolicyAttr{},
+      /*l2_hint=*/xegpu::CachePolicyAttr{},
+      /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+  rewriter.replaceOp(readOp, gatherOp.getResult());
+  return success();
+}
+
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+                                             PatternRewriter &rewriter) {
+
+  Location loc = writeOp.getLoc();
+  VectorType vectorType = writeOp.getVectorType();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+
+  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+                                localOffsets, mask,
+                                /*chunk_size=*/IntegerAttr{},
+                                /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  rewriter.eraseOp(writeOp);
+  return success();
+}
+
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
+    // TODO: This check needs to be replaced with a proper uArch capability check
+    auto chip = xegpu::getChipStr(readOp);
+    if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered load op if the target HW doesn't have 2d block load
+      // support.
+      // TODO: add support for OutOfBound access
+      if (readOp.hasOutOfBoundsDim())
+        return failure();
+      return lowerToScatteredLoadOp(readOp, rewriter);
+    }
+
+    // Perform common data transfer checks.
+    VectorType vecTy = readOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+      return failure();
+
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
       return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     AffineMap readMap = readOp.getPermutationMap();
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
-    VectorType vecTy = readOp.getVectorType();
     Type elementType = vecTy.getElementType();
     unsigned minTransposeBitWidth = 32;
     if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
 
+    // TODO: This check needs to be replaced with a proper uArch capability check
+    auto chip = xegpu::getChipStr(writeOp);
+    if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered store op if the target HW doesn't have 2d block
+      // store support.
+      // TODO: add support for OutOfBound access
+      if (writeOp.hasOutOfBoundsDim())
+        return failure();
+      return lowerToScatteredStoreOp(writeOp, rewriter);
+    }
+
+    // Perform common data transfer checks.
+    VectorType vecTy = writeOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+      return failure();
+
     AffineMap map = writeOp.getPermutationMap();
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
-    VectorType vecTy = writeOp.getVectorType();
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
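
The write side is symmetric. A comparable hedged sketch of the scattered-store path, again in the patch's own commented pseudo-IR style with hypothetical names (%v, %dst, %i, %j) and abbreviated xegpu.store_scatter syntax:

  //   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
  //       : vector<8x16xf32>, memref<8x16xf32>
  // becomes, on a target without 2-D block store support:
  //   %offsets = ... : vector<8x16xindex>            // steps * strides + base offset, as above
  //   %flat    = memref.collapse_shape %dst [[0, 1]] : memref<8x16xf32> into memref<128xf32>
  //   %mask    = vector.constant_mask [8, 16] : vector<8x16xi1>
  //   xegpu.store_scatter %v, %flat, %offsets, %mask // chunk_size and cache hints left unset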
