Skip to content

Commit 04306ca

Browse files
committed
address comments
1 parent 8b99ecc commit 04306ca

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,10 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
637637
```
638638

639639
Example 2:
640-
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
640+
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
641+
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
642+
The source operand could be a raw pointer (uint64_t).
643+
Please refer to create_tdesc for the restriction of memref.
641644
```mlir
642645
%a = memref.alloc() : memref<1024xf32>
643646
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -660,8 +663,11 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
660663
return getSource().getType();
661664
}
662665

663-
Value getTensorDesc() {
664-
return getSource();
666+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
667+
if (auto tdescType = getTensorDescType()) {
668+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
669+
}
670+
return TypedValue<xegpu::TensorDescType>();
665671
}
666672

667673
xegpu::TensorDescType getTensorDescType() {
@@ -728,7 +734,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
728734
```
729735

730736
Example 4:
731-
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
737+
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
738+
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
739+
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740+
for the restriction of memref.
732741
```mlir
733742
%a = memref.alloc() : memref<1024xf32>
734743
%offsets = vector.step : vector<16xindex>
@@ -756,8 +765,11 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
756765
return getSource().getType();
757766
}
758767

759-
Value getTensorDesc() {
760-
return getSource();
768+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
769+
if (auto tdescType = getTensorDescType()) {
770+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
771+
}
772+
return TypedValue<xegpu::TensorDescType>();
761773
}
762774

763775
xegpu::TensorDescType getTensorDescType() {
@@ -833,7 +845,10 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
833845
```
834846

835847
Example 4:
836-
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
848+
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
849+
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
850+
The dest operand could be a raw pointer (uint64_t).
851+
Please refer to create_tdesc for the restriction of memref.
837852
```mlir
838853
%a = memref.alloc() : memref<1024xf32>
839854
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -862,9 +877,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
862877
return getDest().getType();
863878
}
864879

865-
Value getTensorDesc() {
866-
assert(getTensorDescType() && "Expected dest to be a TensorDescType");
867-
return getDest();
880+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
881+
if (auto tdescType = getTensorDescType()) {
882+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
883+
}
884+
return TypedValue<xegpu::TensorDescType>();
868885
}
869886

870887
xegpu::TensorDescType getTensorDescType() {

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ LogicalResult LoadGatherOp::verify() {
708708
auto valueTy = getValueType();
709709

710710
if (tdescTy && !tdescTy.isScattered())
711-
return emitOpError("Expects a scattered TensorDesc.\n");
711+
return emitOpError("Expects a scattered TensorDesc.");
712712

713713
if (!tdescTy && getRankOf(getSource()) > 1)
714714
return emitOpError(

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
485485
xegpu::TensorDescType tdescTy = op.getTensorDescType();
486486

487487
// TODO: handle the unstructure source case (!tdesTy)
488-
if (!tdescTy || !tdescTy.isScattered())
488+
if (!tdescTy || op.getOffsets())
489489
return failure();
490490

491491
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -548,7 +548,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
548548
xegpu::TensorDescType tdescTy = op.getTensorDescType();
549549

550550
// TODO: handle the unstructure source case (!tdesTy)
551-
if (!tdescTy || !tdescTy.isScattered())
551+
if (!tdescTy || op.getOffsets())
552552
return failure();
553553

554554
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -578,7 +578,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
578578
xegpu::TensorDescType tdescTy = op.getTensorDescType();
579579

580580
// TODO: handle the unstructure source case (!tdesTy)
581-
if (!tdescTy || !tdescTy.isScattered())
581+
if (!tdescTy || op.getOffsets())
582582
return failure();
583583

584584
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

0 commit comments

Comments
 (0)