Skip to content

Commit beeac48

Browse files
committed
Update validation to not depend on 'create_nd_tdesc' op
Signed-off-by: dchigarev <[email protected]>
1 parent 614887b commit beeac48

File tree

1 file changed

+6
-23
lines changed

1 file changed

+6
-23
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,29 +128,12 @@ isValidNdOffset(TypedValue<TensorDescType> tDesc,
128128
std::optional<llvm::ArrayRef<int64_t>> constOffsets,
129129
int64_t offsetSize,
130130
function_ref<InFlightDiagnostic()> emitError) {
131-
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
132-
// If CreateNdDescOp is available, we can further
133-
// check the offsets rank against the source rank.
134-
auto staticSource = createTDescOp.getConstShapeAttr();
135-
int64_t sourceRank;
136-
if (!staticSource || staticSource.empty()) {
137-
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
138-
sourceRank = sourceTy.getRank();
139-
} else
140-
sourceRank = staticSource.size();
141-
142-
int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
143-
auto tDescRank = tDesc.getType().getRank();
144-
bool sourceRankMismatch =
145-
((offsetSize != 0) && (offsetSize != sourceRank)) ||
146-
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
147-
bool tdescRankMismatch =
148-
((offsetSize != 0) && (offsetSize != tDescRank)) ||
149-
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
150-
if (sourceRankMismatch && tdescRankMismatch)
151-
return emitError() << "Offsets rank must match either the source or the "
152-
"TensorDesc rank.";
153-
}
131+
int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
132+
auto tDescRank = tDesc.getType().getRank();
133+
if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
134+
((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
135+
return emitError() << "Offsets rank cannot be smaller than tensor "
136+
"descriptor rank.";
154137
return success();
155138
}
156139

0 commit comments

Comments
 (0)