-
Couldn't load subscription status.
- Fork 15k
[MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax #162095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
e581a0b
b56c1cd
8581183
e04202b
37e1843
614887b
beeac48
0ef9ed7
a6ea6f3
392a01f
e334d0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,6 +121,22 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, | |
| return success(); | ||
| } | ||
|
|
||
| // Verify that number of offsets matches either the source rank or the tdesc | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should simply check that offsets can't have lower rank than tdesc rank. It should be fine if it is larger than tdesc rank. I am not sure that it is common practice for the op validation to validate itself using information from producer op. I would rather check this in transformation or lowering passes. Any opinion from @adam-smnk ? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Op verifier should be contained to the op itself without accessing any external data. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Sounds reasonable. Just please clearly define how offsets are interpreted. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Simplified the check to not rely on Also added docs for the new offset syntax |
||
| // rank. | ||
| static LogicalResult | ||
| isValidNdOffset(TypedValue<TensorDescType> tDesc, | ||
| std::optional<llvm::ArrayRef<int64_t>> constOffsets, | ||
| int64_t offsetSize, | ||
| function_ref<InFlightDiagnostic()> emitError) { | ||
| int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0; | ||
| auto tDescRank = tDesc.getType().getRank(); | ||
| if (((offsetSize != 0) && (offsetSize < tDescRank)) || | ||
| ((constOffsetSize != 0) && (constOffsetSize < tDescRank))) | ||
| return emitError() << "Offsets rank cannot be smaller than tensor " | ||
| "descriptor rank."; | ||
| return success(); | ||
| } | ||
|
|
||
| static LogicalResult | ||
| isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, | ||
| VectorType valueTy, int64_t chunkSize, | ||
|
|
@@ -215,8 +231,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | |
| auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); | ||
|
|
||
| // if shape and strides are from Memref, we don't need attributes for them | ||
| // to keep the IR print clean. | ||
| if (staticShape == memrefShape && staticStrides == memrefStrides) { | ||
| // to keep the IR print clean (only do so for full-static case, otherwise | ||
| // printer would fail trying to print empty array-attr). | ||
| if (staticShape == memrefShape && staticStrides == memrefStrides && | ||
| dynamicShape.empty() && dynamicStrides.empty()) { | ||
| staticShapeAttr = DenseI64ArrayAttr(); | ||
| staticStridesAttr = DenseI64ArrayAttr(); | ||
| } | ||
|
|
@@ -277,8 +295,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | |
| auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); | ||
|
|
||
| // if shape and strides are from Memref, we don't need attributes for them | ||
| // to keep the IR print clean. | ||
| if (staticShape == memrefShape && staticStrides == memrefStrides) { | ||
| // to keep the IR print clean (only do so for full-static case, otherwise | ||
| // printer would fail trying to print empty array-attr). | ||
| if (staticShape == memrefShape && staticStrides == memrefStrides && | ||
| dynamicShape.empty() && dynamicStrides.empty()) { | ||
| staticShapeAttr = DenseI64ArrayAttr(); | ||
| staticStridesAttr = DenseI64ArrayAttr(); | ||
| } | ||
|
|
@@ -428,16 +448,9 @@ LogicalResult PrefetchNdOp::verify() { | |
| if (!isReadHintOrNone(getL3HintAttr())) | ||
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); | ||
|
|
||
| int64_t tDescRank = tdescTy.getRank(); | ||
| int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); | ||
| int64_t constOffsetSize = | ||
| getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; | ||
| if (((offsetSize != 0) && (offsetSize != tDescRank)) || | ||
dchigarev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) | ||
| return emitOpError( | ||
| "Mismatched ranks between offsets and tensor descriptor"); | ||
|
|
||
| return success(); | ||
| auto tDesc = getTensorDesc(); | ||
| return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(), | ||
| [&]() { return emitOpError(); }); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -553,16 +566,9 @@ LogicalResult LoadNdOp::verify() { | |
| << " is not consistent with tensor descriptor " | ||
| << tdescTy; | ||
|
|
||
| int64_t tDescRank = tdescTy.getRank(); | ||
| int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); | ||
| int64_t constOffsetSize = | ||
| getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; | ||
| if (((offsetSize != 0) && (offsetSize != tDescRank)) || | ||
| ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) | ||
| return emitOpError( | ||
| "Mismatched ranks between offsets and tensor descriptor"); | ||
|
|
||
| return success(); | ||
| auto tDesc = getTensorDesc(); | ||
| return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(), | ||
| [&]() { return emitOpError(); }); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -647,16 +653,9 @@ LogicalResult StoreNdOp::verify() { | |
| << " is not consistent with tensor descriptor " | ||
| << dstTy; | ||
|
|
||
| int64_t tDescRank = dstTy.getRank(); | ||
| int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); | ||
| int64_t constOffsetSize = | ||
| getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; | ||
| if (((offsetSize != 0) && (offsetSize != tDescRank)) || | ||
| ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) | ||
| return emitOpError( | ||
| "Mismatched ranks between offsets and tensor descriptor"); | ||
|
|
||
| return success(); | ||
| auto tDesc = getTensorDesc(); | ||
| return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(), | ||
| [&]() { return emitOpError(); }); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.