#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+ #include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -68,18 +70,14 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

-   // Perform common data transfer checks.
-   VectorType vecTy = xferOp.getVectorType();
-   if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
-     return failure();
-
  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

+   VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
  return ndDesc;
}

+ // Adjusts the strides of a memref according to a given permutation map for
+ // vector operations.
+ //
+ // This function updates the innermost strides in the `strides` array to
+ // reflect the permutation specified by `permMap`. The permutation is computed
+ // using the inverse and broadcasting-aware version of the permutation map,
+ // and is applied to the relevant strides. This ensures that memory accesses
+ // are consistent with the logical permutation of vector elements.
+ //
+ // Example:
+ //   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+ //   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
+ //   [1, 0]), then after calling this function the last two strides will be
+ //   swapped:
+ //     Original strides:  [s0, s1, s2, s3]
+ //     After permutation: [s0, s1, s3, s2]
+ //
+ static void adjustStridesForPermutation(AffineMap permMap,
+                                         SmallVectorImpl<Value> &strides) {
+
+   AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+   SmallVector<unsigned> perms;
+   invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+   SmallVector<int64_t> perms64(perms.begin(), perms.end());
+   strides = applyPermutation(strides, perms64);
+ }
+
+ // Computes memory strides for vector transfer operations, handling both
+ // static and dynamic memrefs while applying permutation transformations
+ // for XeGPU lowering.
+ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                          PatternRewriter &rewriter) {
+   SmallVector<Value> strides;
+   Value baseMemref = xferOp.getBase();
+   AffineMap permMap = xferOp.getPermutationMap();
+   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+   Location loc = xferOp.getLoc();
+   if (memrefType.hasStaticShape()) {
+     int64_t offset;
+     SmallVector<int64_t> intStrides;
+     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+       return {};
+     // Wrap static strides as MLIR values.
+     for (int64_t s : intStrides)
+       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+   } else {
+     // For dynamic-shape memrefs, use memref.extract_strided_metadata to get
+     // the stride values.
+     unsigned rank = memrefType.getRank();
+     Type indexType = rewriter.getIndexType();
+
+     // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+     //                size0, size1, ..., sizeN-1]
+     SmallVector<Type> resultTypes;
+     resultTypes.push_back(MemRefType::get(
+         {}, memrefType.getElementType())); // base memref (rank 0)
+     resultTypes.push_back(indexType);      // offset
+
+     for (unsigned i = 0; i < rank; ++i)
+       resultTypes.push_back(indexType); // strides
+
+     for (unsigned i = 0; i < rank; ++i)
+       resultTypes.push_back(indexType); // sizes
+
+     auto meta = memref::ExtractStridedMetadataOp::create(
+         rewriter, loc, resultTypes, baseMemref);
+     strides.append(meta.getStrides().begin(), meta.getStrides().end());
+   }
+   // Adjust strides according to the permutation map (e.g., for transpose).
+   adjustStridesForPermutation(permMap, strides);
+   return strides;
+ }
+
+ // This function computes the vector of local offsets for scattered
+ // loads/stores. It is used in the lowering of vector.transfer_read/write to
+ // load_gather/store_scatter. Example:
+ //   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+ //            %cst {in_bounds = [true, true, true, true]} :
+ //            memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+ //
+ //   %6 = vector.step : vector<4xindex>
+ //   %7 = vector.step : vector<2xindex>
+ //   %8 = vector.step : vector<6xindex>
+ //   %9 = vector.step : vector<32xindex>
+ //   %10 = arith.mul %6, 384
+ //   %11 = arith.mul %7, 192
+ //   %12 = arith.mul %8, 32
+ //   %13 = arith.mul %9, 1
+ //   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
+ //   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
+ //   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
+ //   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
+ //   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+ //   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+ //   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+ //   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+ //   %22 = arith.add %18, %19
+ //   %23 = arith.add %20, %21
+ //   %local_offsets = arith.add %22, %23
+ //   %orig_offset = %block_id_y * 4x2x6x32   // consider using an affine map
+ //   %offsets = %orig_offset + %local_offsets
+ static Value computeOffsets(VectorTransferOpInterface xferOp,
+                             PatternRewriter &rewriter,
+                             ArrayRef<Value> strides) {
+   Location loc = xferOp.getLoc();
+   VectorType vectorType = xferOp.getVectorType();
+   SmallVector<Value> indices(xferOp.getIndices().begin(),
+                              xferOp.getIndices().end());
+   ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+   // Create a vector.step op for each dimension of the vector shape.
+   SmallVector<Value> stepVectors;
+   llvm::map_to_vector(vectorShape, [&](int64_t dim) {
+     auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+     auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
+     stepVectors.push_back(stepOp);
+     return stepOp;
+   });
+
+   // Multiply the step vectors by the corresponding strides.
+   size_t memrefRank = strides.size();
+   size_t vectorRank = vectorShape.size();
+   SmallVector<Value> strideMultiplied;
+   for (size_t i = 0; i < vectorRank; ++i) {
+     size_t memrefDim = memrefRank - vectorRank + i;
+     Value strideValue = strides[memrefDim];
+     auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+     auto bcastOp =
+         vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+     auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+     strideMultiplied.push_back(mulOp);
+   }
+
+   // Shape cast each multiplied vector to add singleton dimensions.
+   SmallVector<Value> shapeCasted;
+   for (size_t i = 0; i < vectorRank; ++i) {
+     SmallVector<int64_t> newShape(vectorRank, 1);
+     newShape[i] = vectorShape[i];
+     auto newType = VectorType::get(newShape, rewriter.getIndexType());
+     auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+                                               strideMultiplied[i]);
+     shapeCasted.push_back(castOp);
+   }
+
+   // Broadcast each shape-casted vector to the full vector shape.
+   SmallVector<Value> broadcasted;
+   auto fullIndexVectorType =
+       VectorType::get(vectorShape, rewriter.getIndexType());
+   for (Value shapeCastVal : shapeCasted) {
+     auto broadcastOp = vector::BroadcastOp::create(
+         rewriter, loc, fullIndexVectorType, shapeCastVal);
+     broadcasted.push_back(broadcastOp);
+   }
+
+   // Add all broadcasted vectors together to compute the local offsets.
+   Value localOffsets = broadcasted[0];
+   for (size_t i = 1; i < broadcasted.size(); ++i)
+     localOffsets =
+         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+   // Compute the base offset from the transfer op's indices.
+   Value baseOffset = nullptr;
+   if (!indices.empty()) {
+     baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+     for (size_t i = 0; i < indices.size(); ++i) {
+       Value strideVal = strides[i];
+       Value offsetContrib =
+           arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+       baseOffset =
+           arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+     }
+     // Broadcast the base offset to match the vector shape.
+     Value bcastBase = vector::BroadcastOp::create(
+         rewriter, loc, fullIndexVectorType, baseOffset);
+     localOffsets =
+         arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+   }
+   return localOffsets;
+ }
+
+ // Collapse the memref shape to 1D.
+ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+                                 PatternRewriter &rewriter) {
+   Location loc = xferOp.getLoc();
+
+   Value baseMemref = xferOp.getBase();
+   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+   Type elementType = memrefType.getElementType();
+
+   // Compute the total number of elements in the memref.
+   MemRefType flatMemrefType;
+   if (memrefType.hasStaticShape()) {
+     auto totalElements = memrefType.getNumElements();
+     flatMemrefType = MemRefType::get({totalElements}, elementType);
+   } else {
+     flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+   }
+
+   SmallVector<ReassociationIndices> reassociation;
+   ReassociationIndices allDims =
+       llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+   reassociation.push_back(allDims);
+
+   auto collapseOp = memref::CollapseShapeOp::create(
+       rewriter, loc, flatMemrefType, baseMemref, reassociation);
+   return collapseOp;
+ }
+
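+ // Lowers a vector.transfer_read on a memref source to an xegpu.load_gather
+ // over a flattened 1D view of the memref, using per-element offsets computed
+ // from the memref strides and the transfer indices.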
+ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+                                             PatternRewriter &rewriter) {
+
+   Location loc = readOp.getLoc();
+   VectorType vectorType = readOp.getVectorType();
+   ArrayRef<int64_t> vectorShape = vectorType.getShape();
+   auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+   if (!memrefType)
+     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+   SmallVector<Value> strides = computeStrides(readOp, rewriter);
+   if (strides.empty())
+     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+   Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+   Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+   Value mask = vector::ConstantMaskOp::create(
+       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+       vectorShape);
+   auto gatherOp = xegpu::LoadGatherOp::create(
+       rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+       /*chunk_size=*/IntegerAttr{},
+       /*l1_hint=*/xegpu::CachePolicyAttr{},
+       /*l2_hint=*/xegpu::CachePolicyAttr{},
+       /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+   rewriter.replaceOp(readOp, gatherOp.getResult());
+   return success();
+ }
+
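+ // Lowers a vector.transfer_write on a memref destination to an
+ // xegpu.store_scatter over a flattened 1D view of the memref, mirroring the
+ // gather-based lowering of vector.transfer_read above.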
+ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+                                              PatternRewriter &rewriter) {
+
+   Location loc = writeOp.getLoc();
+   VectorType vectorType = writeOp.getVectorType();
+   ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+   auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+   if (!memrefType)
+     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+   SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+   if (strides.empty())
+     return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
+
+   Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+   Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+   Value mask = vector::ConstantMaskOp::create(
+       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+       vectorShape);
+   xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+                                 localOffsets, mask,
+                                 /*chunk_size=*/IntegerAttr{},
+                                 /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                 /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                 /*l3_hint=*/xegpu::CachePolicyAttr{});
+   rewriter.eraseOp(writeOp);
+   return success();
+ }
+
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

+     // TODO: This check needs to be replaced with a proper uArch capability
+     // check.
+     auto chip = xegpu::getChipStr(readOp);
+     if (chip != "pvc" && chip != "bmg") {
+       // Lower to a scattered load op if the target HW doesn't have 2D block
+       // load support.
+       // TODO: Add support for out-of-bounds accesses.
+       if (readOp.hasOutOfBoundsDim())
+         return failure();
+       return lowerToScatteredLoadOp(readOp, rewriter);
+     }
+
+     // Perform common data transfer checks.
+     VectorType vecTy = readOp.getVectorType();
+     if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+       return failure();
+
    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

-     VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

+     // TODO: This check needs to be replaced with a proper uArch capability
+     // check.
+     auto chip = xegpu::getChipStr(writeOp);
+     if (chip != "pvc" && chip != "bmg") {
+       // Lower to a scattered store op if the target HW doesn't have 2D block
+       // store support.
+       // TODO: Add support for out-of-bounds accesses.
+       if (writeOp.hasOutOfBoundsDim())
+         return failure();
+       return lowerToScatteredStoreOp(writeOp, rewriter);
+     }
+
+     // Perform common data transfer checks.
+     VectorType vecTy = writeOp.getVectorType();
+     if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+       return failure();
+
    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

-     VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),