Conversation

Contributor

@Jianhui-Li Jianhui-Li commented Jul 25, 2025

Add a variant of load/store/prefetch that allows an offset operand. The new xegpu.load variant accepts memref+offset, and the existing tdesc operand will be removed in a future PR.

The semantics are the combination of "create scattered_tdesc" + "xegpu.load with scattered_tdesc". The current xegpu.load accepts a tdesc operand, which encapsulates "memref+offset". This PR folds "memref+offset" directly into xegpu.load, replacing "tdesc". create_tdesc will be removed, since scatter_tdesc would contain only the base address once the offsets are taken away, so there is no point in keeping it.

    // wi level code example
    %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<2xf32>
    xegpu.store %val, %src[%offsets], %mask: vector<1xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
    xegpu.prefetch %src[%0] : ui64, vector<1xindex>

Contributor

@chencha3 chencha3 left a comment

LGTM.

let genVerifyDecl = 1;
}

def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
Contributor

nit: keep a consistent naming convention: XeGPU_TensorDescOrMemRef

Contributor Author

changed

include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
Contributor

I am not sure whether this header include is necessary; so far I haven't seen any changes requiring it. Maybe it can be removed.

Contributor Author

removed


auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto memShape = getShapeOf(memTy);
Contributor

It seems memShape is not used. Maybe it can be removed.

Contributor Author

removed

Contributor

@adam-smnk adam-smnk left a comment
There's a bit too much happening here.
Could this PR be split?

Contributor

@charithaintc charithaintc left a comment
Overall LGTM. Please address the comments.

return getSource().getType();
}

Value getTensorDesc() {
Contributor

Why is this method needed?

Contributor Author

to minimize the change during the transition.

Contributor

nit: This is a bit risky, because the caller might expect a tensor_desc but this may return a memref or i64. I would add a TODO note.

Contributor

Since it returns one specific variant of the source, I'd also suggest returning a TypedValue<xegpu::TensorDescType>.

return getSource();
}

xegpu::TensorDescType getTensorDescType() {
Contributor

A dyn_cast is involved here. Maybe it's better to rename the function to TryGetTensorDesc to avoid confusion?

Contributor Author

The function name needs to stay the same to minimize the change.
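As context for the accessor discussed above, here is a minimal, MLIR-free sketch of a dyn_cast-style getter. All type names below are hypothetical stand-ins for xegpu::TensorDescType and the other source types, not the actual dialect code:

```cpp
#include <cassert>

// Hypothetical stand-ins: after this PR, the op's source may be a
// TensorDesc, a memref, or a raw ui64 pointer, so the tdesc accessor
// must downcast and may yield null.
struct SourceType {
  virtual ~SourceType() = default;
};
struct TensorDescType : SourceType {};
struct MemRefType : SourceType {};

// Analogue of getTensorDescType(): a dyn_cast that returns nullptr
// when the source is not a TensorDesc — the behavior the reviewers
// asked to flag with a clearer name or a TODO note.
TensorDescType *getTensorDescTypeOrNull(SourceType *source) {
  return dynamic_cast<TensorDescType *>(source);
}
```

Callers that unconditionally dereference the result would break on the new memref/pointer sources, which is the risk the review comments point out.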

}

static LogicalResult
isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
Contributor

These two functions are very similar. I think we can reuse/refactor isValidGatherScatterParams to achieve this; I don't see a need to define 2 new functions.

If that is hard to do, at least consider moving the common logic to a helper and reusing the helper.

Contributor Author

refactored to one function.
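A sketch of what such a merged validator might look like. The helper name and the shape rule below are assumptions extrapolated from the chunk_size example in the PR description (chunk_size = 2 with vector<1xindex> offsets/mask loads vector<2xf32>), not the actual refactored code:

```cpp
#include <cassert>

// Hypothetical merged validator: both the tdesc-based and the raw
// pointer/memref-based gather/scatter paths share the same core shape
// rule, so a single helper can serve both call sites.
// Assumed rule: each mask lane moves chunkSize elements, so the value
// length must equal maskLen * chunkSize.
bool isValidGatherScatterShapes(int maskLen, int valueLen, int chunkSize) {
  return valueLen == maskLen * chunkSize;
}
```

The point of the refactor is that only the source type differs between the two variants; the mask/value/chunk consistency check is identical and need only exist once.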

xegpu::TensorDescType tdescTy = op.getTensorDescType();

if (!tdescTy.isScattered())
if (!tdescTy || !tdescTy.isScattered())
Contributor

nit: Looks like the new version is supported by this pattern? Maybe adding a TODO note would help here.

Contributor Author

added

}

// -----
func.func @store_scatter_offset_sg(%src: memref<?xf16>) {
Contributor

Is this a wi test case like the one above? The mask and offsets are just 1 x type.

Contributor Author

changed.

@Jianhui-Li Jianhui-Li changed the title [MLIR][XeGPU] Add offsets to load/store/prefetch [MLIR][XeGPU] Allow load/store/prefetch uses [memref+offset] instead of tdesc Jul 29, 2025
@Jianhui-Li
Contributor Author

@adam-smnk @charithaintc Thanks for the detailed review! I have added documentation and updated the PR description. I also changed the type name and refactored the verifier.

@charithaintc charithaintc self-requested a review July 29, 2025 17:45
Contributor

@charithaintc charithaintc left a comment
LGTM. please address the minor comments if possible.

Comment on lines 675 to 683

if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
Contributor

Suggested change
if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");

Contributor Author

The suggested code doesn't work, since the tdesc can be 2D.

Contributor

Suggested change
if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");

Comment on lines 712 to 719
if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
Contributor

Check the suggested change above. In summary, we should try to combine conditions and return early rather than nesting if conditions.

Contributor Author

changed
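The early-return style suggested in this thread can be sketched as a minimal, MLIR-free C++ analogue. The helper name and the boolean/rank parameters are hypothetical stand-ins for the real op queries (getTensorDescType(), isScattered(), getRankOf(getSource())):

```cpp
#include <cassert>
#include <optional>
#include <string>

// Sketch of the refactored verifier: each condition is combined with
// its guard and returns early, instead of nesting if/else blocks.
// Returns the error message on failure, std::nullopt on success.
std::optional<std::string> verifySource(bool hasTdesc, bool isScattered,
                                        int sourceRank) {
  if (hasTdesc && !isScattered)
    return "Expects a scattered TensorDesc.";
  // The rank check only applies to the raw memref/pointer path; a
  // tensor_desc source may legitimately be 2D, hence the !hasTdesc guard.
  if (!hasTdesc && sourceRank > 1)
    return "Expecting the source is a 1D memref or pointer (uint64_t).";
  return std::nullopt;
}
```

Note the !hasTdesc guard on the second check, which is exactly the fix in the second suggested change after the first suggestion was rejected for rank-2 tdescs.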

Comment on lines 761 to 768
if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getDest()) > 1)
return emitOpError(
"Expecting the dest is a 1D memref or pointer (uint64_t).");
}
Contributor

Check the suggested code change.


Contributor

@adam-smnk adam-smnk left a comment

LGTM % previous nits

@Jianhui-Li Jianhui-Li merged commit e6f360b into llvm:main Jul 30, 2025
9 checks passed