-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Conversion] Convert XeGPU to XeVM pass: Remove lowering support for tensor descriptor with offsets. #157550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern | |
| matchAndRewrite(xegpu::CreateNdDescOp op, | ||
| xegpu::CreateNdDescOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); | ||
| if (mixedOffsets.size() != 0) | ||
| return rewriter.notifyMatchFailure(op, "Offsets not supported."); | ||
| auto loc = op.getLoc(); | ||
| auto source = op.getSource(); | ||
| // Op is lowered to a code sequence that populates payload. | ||
|
|
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern | |
|
|
||
| // Source can be a memref or a pointer (ui64, ui32, i64 or i32). | ||
| SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); | ||
| SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); | ||
| // Descriptor shape is expected to be 2D. | ||
| int64_t rank = mixedSizes.size(); | ||
| if (rank != 2) | ||
|
|
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern | |
| val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); | ||
| return val; | ||
| }; | ||
| // Offsets can be either 2D or not provided (0 is used). | ||
| if (mixedOffsets.size() == 2) { | ||
| offsetW = createOffset(mixedOffsets, 1); | ||
| offsetH = createOffset(mixedOffsets, 0); | ||
| } else if (mixedOffsets.size() == 0) { | ||
| offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); | ||
| offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); | ||
| } else { | ||
| return rewriter.notifyMatchFailure(op, | ||
| "Expected 2D offsets or no offsets."); | ||
| } | ||
| // Offsets are not supported (0 is used). | ||
| offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); | ||
| offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); | ||
| // Get shape values from op fold results. | ||
| baseShapeW = createOffset(mixedSizes, 1); | ||
| baseShapeH = createOffset(mixedSizes, 0); | ||
|
|
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern | |
| } | ||
| }; | ||
|
|
||
| class UpdateNdOffsetToXeVMPattern | ||
| : public OpConversionPattern<xegpu::UpdateNdOffsetOp> { | ||
| using OpConversionPattern::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(xegpu::UpdateNdOffsetOp op, | ||
| xegpu::UpdateNdOffsetOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto loc = op.getLoc(); | ||
| auto mixedOffsets = op.getMixedOffsets(); | ||
| // Only 2D offsets are supported for now. | ||
| if (mixedOffsets.size() != 2) | ||
| return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); | ||
| auto payload = adaptor.getTensorDesc(); | ||
| // Utility for updating payload offset values from op fold result. | ||
| auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { | ||
| Value offset = | ||
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); | ||
| offset = getValueOrCreateCastToIndexLike(rewriter, loc, | ||
| rewriter.getI32Type(), offset); | ||
| Value oldOffset = | ||
| vector::ExtractOp::create(rewriter, loc, payload, payloadPos); | ||
| Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); | ||
| return vector::InsertOp::create(rewriter, loc, newOffset, payload, | ||
| payloadPos); | ||
| }; | ||
| // Update offsets in the payload. | ||
| payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH)); | ||
| payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW)); | ||
| rewriter.replaceOp(op, payload); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| template < | ||
| typename OpType, | ||
| typename = std::enable_if_t<llvm::is_one_of< | ||
|
|
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { | |
| LogicalResult | ||
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto mixedOffsets = op.getMixedOffsets(); | ||
| int64_t opOffsetsSize = mixedOffsets.size(); | ||
| if (opOffsetsSize != 2) | ||
| return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); | ||
| auto loc = op.getLoc(); | ||
| auto ctxt = rewriter.getContext(); | ||
|
|
||
|
|
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { | |
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); | ||
| Value baseShapeH = vector::ExtractOp::create( | ||
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); | ||
| // Offsets provided in two ways: | ||
| // 1. Offsets are extracted from the tensor descriptor. | ||
| // 2. (Mixed) offsets which are provided by the op. | ||
| Value offsetW; | ||
| Value offsetH; | ||
| auto mixedOffsets = op.getMixedOffsets(); | ||
| int64_t opOffsetsSize = mixedOffsets.size(); | ||
| if (opOffsetsSize != 0 && opOffsetsSize != 2) | ||
| return rewriter.notifyMatchFailure(op, | ||
| "Expected 2D offsets or no offsets."); | ||
| if (opOffsetsSize) { | ||
| // If mixed offsets are provided by the op convert them to i32. | ||
| offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); | ||
| offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, | ||
| rewriter.getI32Type(), offsetW); | ||
| offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); | ||
| offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, | ||
| rewriter.getI32Type(), offsetH); | ||
| } else { | ||
| // If offsets are not available, we need to extract them from the tensor | ||
| // descriptor. | ||
| offsetW = vector::ExtractOp::create( | ||
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW)); | ||
| offsetH = vector::ExtractOp::create( | ||
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH)); | ||
| } | ||
| // Offsets are provided by the op. | ||
| // Convert them to i32. | ||
| Value offsetW = | ||
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); | ||
| offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, | ||
| rewriter.getI32Type(), offsetW); | ||
| Value offsetH = | ||
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); | ||
| offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, | ||
| rewriter.getI32Type(), offsetH); | ||
| // Get address space from tensor descriptor memory space. | ||
| auto ptrTypeLLVM = LLVM::LLVMPointerType::get( | ||
| ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); | ||
|
|
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, | |
| return newAddr; | ||
| } | ||
|
|
||
| class CreateDescToXeVMPattern | ||
| : public OpConversionPattern<xegpu::CreateDescOp> { | ||
| using OpConversionPattern::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto eTy = op.getTensorDescType().getElementType(); | ||
| auto eBw = eTy.getIntOrFloatBitWidth(); | ||
| if (eBw % 8 != 0) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Expected element type bit width to be multiple of 8."); | ||
| auto loc = op.getLoc(); | ||
| // Offsets are provided as scalar i64 by type converter. | ||
| auto offsets = adaptor.getOffsets(); | ||
| // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32). | ||
| // But type converter will convert them to integer types. | ||
| Value addr = adaptor.getSource(); | ||
| // ui32 or i32 are passed as i32 so they need to be casted to i64. | ||
| if (addr.getType() != rewriter.getI64Type()) | ||
| addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr); | ||
| auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8); | ||
| rewriter.replaceOp(op, laneAddr); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| class UpdateOffsetToXeVMPattern | ||
| : public OpConversionPattern<xegpu::UpdateOffsetOp> { | ||
| using OpConversionPattern::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(xegpu::UpdateOffsetOp op, | ||
| xegpu::UpdateOffsetOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto eTy = op.getTensorDescType().getElementType(); | ||
| auto eBw = eTy.getIntOrFloatBitWidth(); | ||
| if (eBw % 8 != 0) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Expected element type bit width to be multiple of 8."); | ||
| auto loc = op.getLoc(); | ||
| // Scatter descriptor is provided as scalar i64 by type converter. | ||
| // Offsets are provided as scalar i64 by type converter. | ||
| Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), | ||
| adaptor.getOffsets(), eBw / 8); | ||
| rewriter.replaceOp(op, newOffset); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| template <typename OpType, | ||
| typename = std::enable_if_t<llvm::is_one_of< | ||
| OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>> | ||
|
|
@@ -478,6 +379,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { | |
| LogicalResult | ||
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| Value offsets = adaptor.getOffsets(); | ||
|
||
| if (!offsets) | ||
| return rewriter.notifyMatchFailure(op, | ||
| "Expected offsets to be provided."); | ||
| auto loc = op.getLoc(); | ||
| auto ctxt = rewriter.getContext(); | ||
| auto tdescTy = op.getTensorDescType(); | ||
|
|
@@ -527,21 +432,17 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { | |
| basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), | ||
| basePtrI64); | ||
| } | ||
| Value offsets = adaptor.getOffsets(); | ||
| Value mask = adaptor.getMask(); | ||
| if (offsets) { | ||
| if (dyn_cast<VectorType>(offsets.getType())) { | ||
| // Offset needs be scalar. Single element vector is converted to scalar | ||
| // by type converter. | ||
| return rewriter.notifyMatchFailure(op, | ||
| "Expected offsets to be a scalar."); | ||
| } else { | ||
| // If offsets are provided, we add them to the base pointer. | ||
| // Offsets are in number of elements, we need to multiply by | ||
| // element byte size. | ||
| basePtrI64 = | ||
| addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); | ||
| } | ||
| if (dyn_cast<VectorType>(offsets.getType())) { | ||
| // Offset needs to be scalar. Single element vector is converted to scalar | ||
| // by type converter. | ||
| return rewriter.notifyMatchFailure(op, | ||
| "Expected offsets to be a scalar."); | ||
| } else { | ||
| // Offsets are always present here; add them to the base pointer. | ||
| // Offsets are in number of elements, we need to multiply by | ||
| // element byte size. | ||
| basePtrI64 = addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); | ||
| } | ||
| // Convert base pointer (i64) to LLVM pointer type. | ||
| Value basePtrLLVM = | ||
|
|
@@ -1011,13 +912,12 @@ struct ConvertXeGPUToXeVMPass | |
| //===----------------------------------------------------------------------===// | ||
| void mlir::populateXeGPUToXeVMConversionPatterns( | ||
| const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { | ||
| patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern, | ||
| patterns.add<CreateNdDescToXeVMPattern, | ||
| LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>, | ||
| LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>, | ||
| LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>( | ||
| typeConverter, patterns.getContext()); | ||
| patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern, | ||
| AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, | ||
| patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, | ||
| LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, | ||
| LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( | ||
| typeConverter, patterns.getContext()); | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: drop the second "not" in "Offsets are not supported not"