@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -311,32 +276,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op;
+    // convert them to i32.
+    Value offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetW);
+    Value offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +369,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +377,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offsets = adaptor.getOffsets();
+    if (!offsets)
+      return rewriter.notifyMatchFailure(op, "Expected offsets to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +429,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offsets.getType())) {
+      // Offset needs to be scalar. A single-element vector is converted to
+      // a scalar by the type converter.
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected offsets to be a scalar.");
+    } else {
+      // Offsets are provided; add them to the base pointer.
+      // Offsets are in number of elements, so multiply them by
+      // the element byte size.
+      basePtrI64 =
+          addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());