@@ -128,29 +128,12 @@ isValidNdOffset(TypedValue<TensorDescType> tDesc,
128128 std::optional<llvm::ArrayRef<int64_t >> constOffsets,
129129 int64_t offsetSize,
130130 function_ref<InFlightDiagnostic()> emitError) {
131- if (auto createTDescOp = tDesc.getDefiningOp <CreateNdDescOp>()) {
132- // If CreateNdDescOp is available, we can further
133- // check the offsets rank against the source rank.
134- auto staticSource = createTDescOp.getConstShapeAttr ();
135- int64_t sourceRank;
136- if (!staticSource || staticSource.empty ()) {
137- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType ());
138- sourceRank = sourceTy.getRank ();
139- } else
140- sourceRank = staticSource.size ();
141-
142- int64_t constOffsetSize = constOffsets ? constOffsets->size () : 0 ;
143- auto tDescRank = tDesc.getType ().getRank ();
144- bool sourceRankMismatch =
145- ((offsetSize != 0 ) && (offsetSize != sourceRank)) ||
146- ((constOffsetSize != 0 ) && (constOffsetSize != sourceRank));
147- bool tdescRankMismatch =
148- ((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
149- ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank));
150- if (sourceRankMismatch && tdescRankMismatch)
151- return emitError () << " Offsets rank must match either the source or the "
152- " TensorDesc rank." ;
153- }
131+ int64_t constOffsetSize = constOffsets ? constOffsets->size () : 0 ;
132+ auto tDescRank = tDesc.getType ().getRank ();
133+ if (((offsetSize != 0 ) && (offsetSize < tDescRank)) ||
134+ ((constOffsetSize != 0 ) && (constOffsetSize < tDescRank)))
135+ return emitError () << " Offsets rank cannot be smaller than tensor "
136+ " descriptor rank." ;
154137 return success ();
155138}
156139
0 commit comments