@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
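Only the tail of the createOffset lambda is visible in the context lines above. A plausible reconstruction, inferred from the visible cast line and the call sites (an assumption, not text from the patch): it materializes the OpFoldResult at index idx as a value of the payload element type.

  // Materialize ofrVec[idx] as a Value of the payload element type (i32),
  // used for both the shape fields and, formerly, the offset fields.
  auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                          unsigned idx) -> Value {
    Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
    val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
    return val;
  };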
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
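With the descriptor's offset fields now always zero, xegpu.update_nd_offset has nothing to update in the payload, so its pattern is deleted outright rather than adapted. For reference, its core idiom was a read-modify-write of one lane of the payload vector, condensed here as a sketch using only calls from the removed lines (pos stands for an NdTdescOffset lane index and step for the already-converted i32 offset):

  // Read one lane of the payload, add the step, and write the lane back.
  Value oldOffset = vector::ExtractOp::create(rewriter, loc, payload, pos);
  Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, step);
  payload = vector::InsertOp::create(rewriter, loc, newOffset, payload, pos);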
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // Convert them to i32.
+    Value offsetW =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetW);
+    Value offsetH =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
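The replacement code performs the same OpFoldResult-to-i32 conversion twice, once per dimension. A hypothetical local helper (not part of the patch) that captures the idiom:

  // Materialize an OpFoldResult as a Value, then cast it to i32.
  auto toI32 = [&](OpFoldResult ofr) -> Value {
    Value v = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
    return getValueOrCreateCastToIndexLike(rewriter, loc,
                                           rewriter.getI32Type(), v);
  };
  Value offsetW = toI32(mixedOffsets[1]); // fastest-varying dimension
  Value offsetH = toI32(mixedOffsets[0]);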
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
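The two deleted patterns were callers of the addOffset helper whose tail appears at the top of this hunk. Only its last two lines are visible in the diff; a presumed body, reconstructed from the signature in the hunk header and from the call sites passing eBw / 8 or elemByteSize (an assumption, not text from the patch):

  static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                         Value baseAddr, Value offset, int64_t elemByteSize) {
    // Scale the element-count offset by the element byte size and add it
    // to the i64 base address.
    Value byteSize = arith::ConstantIntOp::create(
        rewriter, loc, rewriter.getI64Type(), elemByteSize);
    Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
    Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
    return newAddr;
  }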
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offset = adaptor.getOffsets();
+    if (!offset)
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
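Combined with the scalar check in the next hunk, the pattern's offset handling now reduces to a short unconditional sequence. A condensed sketch assembled from both hunks (it substitutes isa for the patch's dyn_cast, since only a boolean test is needed):

  Value offset = adaptor.getOffsets();
  if (!offset)
    return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
  if (isa<VectorType>(offset.getType()))
    return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
  // The offset counts elements; addOffset scales it by the element byte size.
  basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);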
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offset.getType())) {
+      // Offset needs to be scalar. A single-element vector is converted to
+      // a scalar by the type converter.
+      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
+    } else {
+      // The offset is added to the base pointer. It is given in number of
+      // elements, so it needs to be multiplied by the element byte size
+      // (addOffset does the scaling).
+      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());