-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU] Allow load/store/prefetch uses [memref+offset] instead of tdesc #150576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
chencha3
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
| let genVerifyDecl = 1; | ||
| } | ||
|
|
||
| def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: keep a consistent naming convention: XeGPU_TensorDescOrMemRef
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
| include "mlir/Interfaces/ShapedOpInterfaces.td" | ||
| include "mlir/Interfaces/SideEffectInterfaces.td" | ||
| include "mlir/Interfaces/ViewLikeInterface.td" | ||
| include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure whether this header include is necessary, So far, I didn't see the changes requiring this. Maybe It can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
|
|
||
| auto maskShape = getShapeOf(maskTy); | ||
| auto valueShape = getShapeOf(valueTy); | ||
| auto memShape = getShapeOf(memTy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems memShape is not used. Maybe it can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
adam-smnk
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a bit too much happening here.
Could this PR be split?
charithaintc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall LGTM. please address the comments.
| return getSource().getType(); | ||
| } | ||
|
|
||
| Value getTensorDesc() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this method is needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to minimize the change during the transition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is bit risky. because the caller might expect a tensor_desc but this may return a memref or I64. I would add a TODO note.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it returns one specific variant of the source, I'd also suggest it return a TypedValue<xegpu::TensorDescType>
| return getSource(); | ||
| } | ||
|
|
||
| xegpu::TensorDescType getTensorDescType() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here dyn_cast is involved. Maybe better to rename the function to TryGetTensorDesc to avoid confusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name needs to keep the same to minimize the change.
| } | ||
|
|
||
| static LogicalResult | ||
| isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these two functions are very similar. I think we can reuse/refactor isValidGatherScatterParams to achive this. I don't see a need to define 2 new functions.
If that is hard to do, at least consider moving common logic to a helper and reuse the helper.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
refactored to one function.
| xegpu::TensorDescType tdescTy = op.getTensorDescType(); | ||
|
|
||
| if (!tdescTy.isScattered()) | ||
| if (!tdescTy || !tdescTy.isScattered()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Looks like the new version is supported by this pattern? Maybe add a TODO note will help here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
mlir/test/Dialect/XeGPU/invalid.mlir
Outdated
| } | ||
|
|
||
| // ----- | ||
| func.func @store_scatter_offset_sg(%src: memref<?xf16>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a wi test case like the one above? mask and offsets are just 1 x type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed.
|
@adam-smnk @charithaintc Thanks for the detailed review! I have added documentation and updated the PR description. Also changed the type name and refactor the verifier. |
charithaintc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. please address the minor comments if possible.
|
|
||
| if (tdescTy) { | ||
| if (!tdescTy.isScattered()) | ||
| return emitOpError("Expects a scattered TensorDesc.\n"); | ||
| } else { | ||
| if (getRankOf(getSource()) > 1) | ||
| return emitOpError( | ||
| "Expecting the source is a 1D memref or pointer (uint64_t)."); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (tdescTy) { | |
| if (!tdescTy.isScattered()) | |
| return emitOpError("Expects a scattered TensorDesc.\n"); | |
| } else { | |
| if (getRankOf(getSource()) > 1) | |
| return emitOpError( | |
| "Expecting the source is a 1D memref or pointer (uint64_t)."); | |
| } | |
| if (tdescTy && !tdescTy.isScattered()) | |
| return emitOpError("Expects a scattered TensorDesc.\n"); | |
| if (getRankOf(getSource()) > 1) | |
| return emitOpError( | |
| "Expecting the source is a 1D memref or pointer (uint64_t)."); | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The suggested code doesn't work since the Tdesc can be 2d.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (tdescTy) { | |
| if (!tdescTy.isScattered()) | |
| return emitOpError("Expects a scattered TensorDesc.\n"); | |
| } else { | |
| if (getRankOf(getSource()) > 1) | |
| return emitOpError( | |
| "Expecting the source is a 1D memref or pointer (uint64_t)."); | |
| } | |
| if (tdescTy && !tdescTy.isScattered()) | |
| return emitOpError("Expects a scattered TensorDesc.\n"); | |
| if (!tdescTy && getRankOf(getSource()) > 1) | |
| return emitOpError( | |
| "Expecting the source is a 1D memref or pointer (uint64_t)."); |
| if (tdescTy) { | ||
| if (!tdescTy.isScattered()) | ||
| return emitOpError("Expects a scattered TensorDesc.\n"); | ||
| } else { | ||
| if (getRankOf(getSource()) > 1) | ||
| return emitOpError( | ||
| "Expecting the source is a 1D memref or pointer (uint64_t)."); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check the suggested change above. in summary, we should try to combine conditions and return early rather than nesting if conditions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
| if (tdescTy) { | ||
| if (!tdescTy.isScattered()) | ||
| return emitOpError("Expects a scattered TensorDesc.\n"); | ||
| } else { | ||
| if (getRankOf(getDest()) > 1) | ||
| return emitOpError( | ||
| "Expecting the dest is a 1D memref or pointer (uint64_t)."); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check suggested code change.
| return getSource().getType(); | ||
| } | ||
|
|
||
| Value getTensorDesc() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is bit risky. because the caller might expect a tensor_desc but this may return a memref or I64. I would add a TODO note.
| return getSource().getType(); | ||
| } | ||
|
|
||
| Value getTensorDesc() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it returns one specific variant of the source, I'd also suggest it return a TypedValue<xegpu::TensorDescType>
adam-smnk
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % previous nits
Add variant of load/store/prefetch to allow offset. The new xegpu.load variant accepts memref+offset, and the existing tdesc operand will be removed in the future PR.
The semantics are combination of "creating scattered_tdesc + xegpu.load with scattered_tdesc". The current xegpu.load accepts tdesc operand, which encapsulates "memref+offset". This PR "fold" "memref+offset" directly to xegpu.load replacing "tdesc". Create_tdesc will be removed as scatter_tdesc only contains base address after offsets being taken away, so there is no point to keep it.