@@ -121,22 +121,6 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
121121 return success ();
122122}
123123
124- // Verify that number of offsets matches either the source rank or the tdesc
125- // rank.
126- static LogicalResult
127- isValidNdOffset (TypedValue<TensorDescType> tDesc,
128- std::optional<llvm::ArrayRef<int64_t >> constOffsets,
129- int64_t offsetSize,
130- function_ref<InFlightDiagnostic()> emitError) {
131- int64_t constOffsetSize = constOffsets ? constOffsets->size () : 0 ;
132- auto tDescRank = tDesc.getType ().getRank ();
133- if (((offsetSize != 0 ) && (offsetSize < tDescRank)) ||
134- ((constOffsetSize != 0 ) && (constOffsetSize < tDescRank)))
135- return emitError () << " Offsets rank cannot be smaller than tensor "
136- " descriptor rank." ;
137- return success ();
138- }
139-
140124static LogicalResult
141125isValidGatherScatterBufferParams (Type offsetsTy, Type maskTy,
142126 VectorType valueTy, int64_t chunkSize,
@@ -274,10 +258,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
274258 auto [memrefStrides, _] = memrefTy.getStridesAndOffset ();
275259
276260 // if shape and strides are from Memref, we don't need attributes for them
277- // to keep the IR print clean (only do so for full-static case, otherwise
278- // printer would fail trying to print empty array-attr).
279- if (staticShape == memrefShape && staticStrides == memrefStrides &&
280- dynamicShape.empty () && dynamicStrides.empty ()) {
261+ // to keep the IR print clean.
262+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
281263 staticShapeAttr = DenseI64ArrayAttr ();
282264 staticStridesAttr = DenseI64ArrayAttr ();
283265 }
@@ -338,10 +320,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
338320 auto [memrefStrides, _] = memrefTy.getStridesAndOffset ();
339321
340322 // if shape and strides are from Memref, we don't need attributes for them
341- // to keep the IR print clean (only do so for full-static case, otherwise
342- // printer would fail trying to print empty array-attr).
343- if (staticShape == memrefShape && staticStrides == memrefStrides &&
344- dynamicShape.empty () && dynamicStrides.empty ()) {
323+ // to keep the IR print clean.
324+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
345325 staticShapeAttr = DenseI64ArrayAttr ();
346326 staticStridesAttr = DenseI64ArrayAttr ();
347327 }
@@ -491,9 +471,16 @@ LogicalResult PrefetchNdOp::verify() {
491471 if (!isReadHintOrNone (getL3HintAttr ()))
492472 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
493473
494- auto tDesc = getTensorDesc ();
495- return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
496- [&]() { return emitOpError (); });
474+ int64_t tDescRank = tdescTy.getRank ();
475+ int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
476+ int64_t constOffsetSize =
477+ getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
478+ if (((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
479+ ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank)))
480+ return emitOpError (
481+ " Mismatched ranks between offsets and tensor descriptor" );
482+
483+ return success ();
497484}
498485
499486// ===----------------------------------------------------------------------===//
@@ -609,9 +596,16 @@ LogicalResult LoadNdOp::verify() {
609596 << " is not consistent with tensor descriptor "
610597 << tdescTy;
611598
612- auto tDesc = getTensorDesc ();
613- return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
614- [&]() { return emitOpError (); });
599+ int64_t tDescRank = tdescTy.getRank ();
600+ int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
601+ int64_t constOffsetSize =
602+ getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
603+ if (((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
604+ ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank)))
605+ return emitOpError (
606+ " Mismatched ranks between offsets and tensor descriptor" );
607+
608+ return success ();
615609}
616610
617611// ===----------------------------------------------------------------------===//
@@ -696,9 +690,16 @@ LogicalResult StoreNdOp::verify() {
696690 << " is not consistent with tensor descriptor "
697691 << dstTy;
698692
699- auto tDesc = getTensorDesc ();
700- return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
701- [&]() { return emitOpError (); });
693+ int64_t tDescRank = dstTy.getRank ();
694+ int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
695+ int64_t constOffsetSize =
696+ getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
697+ if (((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
698+ ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank)))
699+ return emitOpError (
700+ " Mismatched ranks between offsets and tensor descriptor" );
701+
702+ return success ();
702703}
703704
704705// ===----------------------------------------------------------------------===//
0 commit comments