@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
121121 return success ();
122122}
123123
124+ // Verify that number of offsets matches either the source rank or the tdesc
125+ // rank.
126+ static LogicalResult
127+ isValidNdOffset (TypedValue<TensorDescType> tDesc,
128+ std::optional<llvm::ArrayRef<long int >> constOffsets,
129+ int64_t offsetSize,
130+ function_ref<InFlightDiagnostic()> emitError) {
131+ if (auto createTDescOp = tDesc.getDefiningOp <CreateNdDescOp>()) {
132+ // If CreateNdDescOp is available, we can further
133+ // check the offsets rank against the source rank.
134+ auto staticSource = createTDescOp.getConstShapeAttr ();
135+ int64_t sourceRank;
136+ if (!staticSource || staticSource.empty ()) {
137+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType ());
138+ sourceRank = sourceTy.getRank ();
139+ } else
140+ sourceRank = staticSource.size ();
141+
142+ int64_t constOffsetSize = constOffsets ? constOffsets->size () : 0 ;
143+ auto tDescRank = tDesc.getType ().getRank ();
144+ bool sourceRankMismatch =
145+ ((offsetSize != 0 ) && (offsetSize != sourceRank)) ||
146+ ((constOffsetSize != 0 ) && (constOffsetSize != sourceRank));
147+ bool tdescRankMismatch =
148+ ((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
149+ ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank));
150+ if (sourceRankMismatch && tdescRankMismatch)
151+ return emitError () << " Offsets rank must match either the source or the "
152+ " TensorDesc rank." ;
153+ }
154+ return success ();
155+ }
156+
124157static LogicalResult
125158isValidGatherScatterBufferParams (Type offsetsTy, Type maskTy,
126159 VectorType valueTy, int64_t chunkSize,
@@ -433,33 +466,8 @@ LogicalResult PrefetchNdOp::verify() {
433466 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
434467
435468 auto tDesc = getTensorDesc ();
436- if (auto createTDescOp = tDesc.getDefiningOp <CreateNdDescOp>()) {
437- // If CreateNdDescOp is available, we can further
438- // check the offsets rank against the source rank.
439- auto staticSource = createTDescOp.getConstShapeAttr ();
440- int64_t sourceRank;
441- if (!staticSource || staticSource.empty ()) {
442- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType ());
443- sourceRank = sourceTy.getRank ();
444- } else
445- sourceRank = staticSource.size ();
446-
447- int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
448- int64_t constOffsetSize =
449- getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
450- auto tDescRank = tdescTy.getRank ();
451- bool sourceRankMismatch =
452- ((offsetSize != 0 ) && (offsetSize != sourceRank)) ||
453- ((constOffsetSize != 0 ) && (constOffsetSize != sourceRank));
454- bool tdescRankMismatch =
455- ((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
456- ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank));
457- if (sourceRankMismatch && tdescRankMismatch)
458- return emitOpError (
459- " Offsets rank must match either the source or the TensorDesc rank." );
460- }
461-
462- return success ();
469+ return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
470+ [&]() { return emitOpError (); });
463471}
464472
465473// ===----------------------------------------------------------------------===//
@@ -576,33 +584,8 @@ LogicalResult LoadNdOp::verify() {
576584 << tdescTy;
577585
578586 auto tDesc = getTensorDesc ();
579- if (auto createTDescOp = tDesc.getDefiningOp <CreateNdDescOp>()) {
580- // If CreateNdDescOp is available, we can further
581- // check the offsets rank against the source rank.
582- auto staticSource = createTDescOp.getConstShapeAttr ();
583- int64_t sourceRank;
584- if (!staticSource || staticSource.empty ()) {
585- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType ());
586- sourceRank = sourceTy.getRank ();
587- } else
588- sourceRank = staticSource.size ();
589-
590- int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
591- int64_t constOffsetSize =
592- getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
593- auto tDescRank = tdescTy.getRank ();
594- bool sourceRankMismatch =
595- ((offsetSize != 0 ) && (offsetSize != sourceRank)) ||
596- ((constOffsetSize != 0 ) && (constOffsetSize != sourceRank));
597- bool tdescRankMismatch =
598- ((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
599- ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank));
600- if (sourceRankMismatch && tdescRankMismatch)
601- return emitOpError (
602- " Offsets rank must match either the source or the TensorDesc rank." );
603- }
604-
605- return success ();
587+ return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
588+ [&]() { return emitOpError (); });
606589}
607590
608591// ===----------------------------------------------------------------------===//
@@ -688,33 +671,8 @@ LogicalResult StoreNdOp::verify() {
688671 << dstTy;
689672
690673 auto tDesc = getTensorDesc ();
691- if (auto createTDescOp = tDesc.getDefiningOp <CreateNdDescOp>()) {
692- // If CreateNdDescOp is available, we can further
693- // check the offsets rank against the source rank.
694- auto staticSource = createTDescOp.getConstShapeAttr ();
695- int64_t sourceRank;
696- if (!staticSource || staticSource.empty ()) {
697- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType ());
698- sourceRank = sourceTy.getRank ();
699- } else
700- sourceRank = staticSource.size ();
701-
702- int64_t offsetSize = static_cast <int64_t >(getOffsets ().size ());
703- int64_t constOffsetSize =
704- getConstOffsetsAttr () ? getConstOffsetsAttr ().size () : 0 ;
705- auto tDescRank = dstTy.getRank ();
706- bool sourceRankMismatch =
707- ((offsetSize != 0 ) && (offsetSize != sourceRank)) ||
708- ((constOffsetSize != 0 ) && (constOffsetSize != sourceRank));
709- bool tdescRankMismatch =
710- ((offsetSize != 0 ) && (offsetSize != tDescRank)) ||
711- ((constOffsetSize != 0 ) && (constOffsetSize != tDescRank));
712- if (sourceRankMismatch && tdescRankMismatch)
713- return emitOpError (
714- " Offsets rank must match either the source or the TensorDesc rank." );
715- }
716-
717- return success ();
674+ return isValidNdOffset (tDesc, getConstOffsets (), getMixedOffsets ().size (),
675+ [&]() { return emitOpError (); });
718676}
719677
720678// ===----------------------------------------------------------------------===//
0 commit comments