Skip to content

Commit 8b99ecc

Browse files
committed
minor polish
1 parent da7142a commit 8b99ecc

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
863863
}
864864

865865
Value getTensorDesc() {
866-
return getDest();
866+
assert(getTensorDescType() && "Expected dest to be a TensorDescType");
867+
return getDest();
867868
}
868869

869870
xegpu::TensorDescType getTensorDescType() {

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -673,14 +673,12 @@ LogicalResult CreateDescOp::verify() {
673673
LogicalResult PrefetchOp::verify() {
674674
auto tdescTy = getTensorDescType();
675675

676-
if (tdescTy) {
677-
if (!tdescTy.isScattered())
678-
return emitOpError("Expects a scattered TensorDesc.\n");
679-
} else {
680-
if (getRankOf(getSource()) > 1)
681-
return emitOpError(
682-
"Expecting the source is a 1D memref or pointer (uint64_t).");
683-
}
676+
if (tdescTy && !tdescTy.isScattered())
677+
return emitOpError("Expects a scattered TensorDesc.\n");
678+
679+
if (!tdescTy && getRankOf(getSource()) > 1)
680+
return emitOpError(
681+
"Expecting the source is a 1D memref or pointer (uint64_t).");
684682

685683
if (!isReadHintOrNone(getL1HintAttr()))
686684
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -709,14 +707,12 @@ LogicalResult LoadGatherOp::verify() {
709707
auto maskTy = getMaskType();
710708
auto valueTy = getValueType();
711709

712-
if (tdescTy) {
713-
if (!tdescTy.isScattered())
714-
return emitOpError("Expects a scattered TensorDesc.\n");
715-
} else {
716-
if (getRankOf(getSource()) > 1)
717-
return emitOpError(
718-
"Expecting the source is a 1D memref or pointer (uint64_t).");
719-
}
710+
if (tdescTy && !tdescTy.isScattered())
711+
return emitOpError("Expects a scattered TensorDesc.\n");
712+
713+
if (!tdescTy && getRankOf(getSource()) > 1)
714+
return emitOpError(
715+
"Expecting the source is a 1D memref or pointer (uint64_t).");
720716

721717
if (!isReadHintOrNone(getL1HintAttr()))
722718
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -758,14 +754,12 @@ LogicalResult StoreScatterOp::verify() {
758754
auto maskTy = getMaskType();
759755
auto valueTy = getValueType();
760756

761-
if (tdescTy) {
762-
if (!tdescTy.isScattered())
763-
return emitOpError("Expects a scattered TensorDesc.\n");
764-
} else {
765-
if (getRankOf(getDest()) > 1)
766-
return emitOpError(
767-
"Expecting the dest is a 1D memref or pointer (uint64_t).");
768-
}
757+
if (tdescTy && !tdescTy.isScattered())
758+
return emitOpError("Expects a scattered TensorDesc.\n");
759+
760+
if (!tdescTy && getRankOf(getDest()) > 1)
761+
return emitOpError(
762+
"Expecting the dest is a 1D memref or pointer (uint64_t).");
769763

770764
if (!isWriteHintOrNone(getL1HintAttr()))
771765
return emitOpError("invalid l1_hint: ") << getL1HintAttr();

0 commit comments

Comments
 (0)