@@ -133,29 +133,35 @@ static bool isSmallerThan(ArrayRef<int64_t> sourceShape,
133133 });
134134}
135135
136- // ===----------------------------------------------------------------------===//
137- // ScatterOp
138- // ===----------------------------------------------------------------------===//
139-
140- LogicalResult ScatterOp::verify () {
141- Operation *op = getOperation ();
142- if (getInputs ().size () != 2 ) {
143- return op->emitOpError (" expected two input operands" );
136+ // / Helper function to verify both `scatter` and `gather`. Since both ops share
137+ // / the same sementics, we can use the same function to verify them. Note: this
138+ // / is written from the perspective of `scatter` op. For gather, `updateType`
139+ // / maps to the type of the output and `originalType` maps to the type of the
140+ // / `source`.
141+ template <typename OpTy>
142+ static LogicalResult
143+ verifyGatherScatter (OpTy op, int64_t sliceRank, ShapedType originalType,
144+ ShapedType updateType, StringRef originalName,
145+ StringRef updateName) {
146+ static_assert (llvm::is_one_of<OpTy, GatherOp, ScatterOp>::value,
147+ " applies to only gather or scatter operations" );
148+ if (op.getInputs ().size () != 2 ) {
149+ return op.emitOpError (" expected two input operands" );
144150 }
145- if (getOutputs ().size () != 1 ) {
146- return op-> emitOpError (" expected one output operand" );
151+ if (op. getOutputs ().size () != 1 ) {
152+ return op. emitOpError (" expected one output operand" );
147153 }
148154
149- auto indicesType = getIndicesType ();
155+ auto indicesType = op. getIndicesType ();
150156 if (indicesType.getRank () < 1 ||
151157 !isa<IntegerType>(indicesType.getElementType ())) {
152158 return op->emitOpError (" expected indices to be of rank 1 or greater and of "
153159 " integer element type" );
154160 }
155161
156- ArrayRef<int64_t > dimMap = getDimensionMap ();
162+ ArrayRef<int64_t > dimMap = op. getDimensionMap ();
157163 if (failed (isPermSequence (
158- [&]() { return this ->emitOpError (" dimension map is invalid." ); },
164+ [&]() { return op ->emitOpError (" dimension map is invalid." ); },
159165 dimMap))) {
160166 return failure ();
161167 }
@@ -164,23 +170,24 @@ LogicalResult ScatterOp::verify() {
164170 return op->emitOpError (" dimension map must have at least one element" );
165171 }
166172
167- const size_t indexDepth = getIndexDepth ();
168- auto originalType = getOriginalType ();
169- auto updateType = getUpdateType ();
173+ const size_t indexDepth = op.getIndexDepth ();
170174 const auto originalSliceRank = originalType.getRank () - indexDepth;
171175 if (originalSliceRank < 0 ) {
172- return op->emitOpError (
173- " expected original rank to be greater or equal to index depth" );
176+ return op->emitOpError (" expected " + originalName +
177+ " rank to be greater or equal to index depth" );
174178 }
175179 if (updateType.getRank () < originalSliceRank) {
176- return op->emitOpError (
177- " expected update to be at least the rank of non indexed original dims" );
180+ return op->emitOpError (" expected " + updateName +
181+ " to be at least the rank of non indexed " +
182+ originalName + " dims" );
178183 }
179184 const size_t batchRank = updateType.getRank () - originalSliceRank;
180185
181186 if (updateType.getRank () - batchRank != originalSliceRank) {
182- return op->emitOpError (" expected rank of update value - batch rank to be "
183- " equal to rank of original value - index depth" );
187+ return op->emitOpError (" expected rank of " + updateName +
188+ " value - batch rank to be "
189+ " equal to rank of " +
190+ originalName + " value - index depth" );
184191 }
185192
186193 if ((indicesType.getRank () != batchRank || indexDepth != 1 ) &&
@@ -196,8 +203,8 @@ LogicalResult ScatterOp::verify() {
196203 llvm::mismatch (indicesType.getShape ().take_front (batchRank),
197204 updateType.getShape ().take_front (batchRank));
198205 if (indicesIt != indicesType.getShape ().take_front (batchRank).end ()) {
199- return op->emitOpError (
200- " mismatch in shape of indices and update value at dim#" )
206+ return op->emitOpError (" mismatch in shape of indices and " + updateName +
207+ " value at dim#" )
201208 << (indicesIt - indicesType.getShape ().begin ());
202209 }
203210 }
@@ -208,7 +215,7 @@ LogicalResult ScatterOp::verify() {
208215 }
209216
210217 {
211- for (auto idx : llvm::seq<int64_t >(0 , getUpdateSliceRank () )) {
218+ for (auto idx : llvm::seq<int64_t >(0 , sliceRank )) {
212219 int64_t updateDim = idx + batchRank;
213220 int64_t origDim = idx + indexDepth;
214221 if (originalType.isDynamicDim (origDim) ||
@@ -217,14 +224,14 @@ LogicalResult ScatterOp::verify() {
217224 }
218225 if (originalType.getDimSize (origDim) !=
219226 updateType.getDimSize (updateDim)) {
220- return op->emitOpError (" shape of update value dim#" )
221- << (updateDim) << " must match original value at dim# "
222- << (origDim);
227+ return op->emitOpError (" shape of " + updateName + " value dim#" )
228+ << (updateDim)
229+ << " must match " + originalName + " value at dim# " << (origDim);
223230 }
224231 }
225232 }
226233
227- Region ®ion = this -> getRegion ();
234+ Region ®ion = op. getRegion ();
228235 Block *body = ®ion.front ();
229236 if (body->getNumArguments () != 2 ) {
230237 return op->emitOpError (" expected region to have two arguments" );
@@ -238,12 +245,12 @@ LogicalResult ScatterOp::verify() {
238245 }
239246 if (arg0Type != updateType.getElementType ()) {
240247 return op->emitOpError (" mismatch in argument 0 of region " )
241- << arg0Type << " and element type of update value "
248+ << arg0Type << " and element type of " + updateName + " value "
242249 << updateType.getElementType ();
243250 }
244251 if (arg1Type != originalType.getElementType ()) {
245252 return op->emitOpError (" mismatch in argument 1 of region " )
246- << arg1Type << " and element type of original value "
253+ << arg1Type << " and element type of " + originalName + " value "
247254 << originalType.getElementType ();
248255 }
249256 if (arg0Type != arg1Type) {
@@ -262,6 +269,15 @@ LogicalResult ScatterOp::verify() {
262269 return success ();
263270}
264271
272+ // ===----------------------------------------------------------------------===//
273+ // ScatterOp
274+ // ===----------------------------------------------------------------------===//
275+
276+ LogicalResult ScatterOp::verify () {
277+ return verifyGatherScatter (*this , getUpdateSliceRank (), getOriginalType (),
278+ getUpdateType (), " original" , " update" );
279+ }
280+
265281LogicalResult
266282ScatterOp::reifyResultShapes (OpBuilder &b,
267283 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
@@ -285,6 +301,22 @@ SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
285301 return {AffineMap (nullptr )};
286302}
287303
304+ // ===----------------------------------------------------------------------===//
305+ // GatherOp
306+ // ===----------------------------------------------------------------------===//
307+
308+ LogicalResult GatherOp::verify () {
309+ return verifyGatherScatter (*this , getOutputSliceRank (), getSourceType (),
310+ getOutputType (), " source" , " output" );
311+ }
312+
313+ LogicalResult
314+ GatherOp::reifyResultShapes (OpBuilder &b,
315+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
316+ return cast<LinalgExtOp>(getOperation ())
317+ .reifyResultShapes (b, reifiedReturnShapes);
318+ }
319+
288320// ===----------------------------------------------------------------------===//
289321// SortOp
290322// ===----------------------------------------------------------------------===//
@@ -1950,6 +1982,7 @@ LogicalResult IREE::LinalgExt::IndexOp::verify() {
19501982 }
19511983
19521984DEFINE_OP_GET_EFFECTS (ScatterOp)
1985+ DEFINE_OP_GET_EFFECTS (GatherOp)
19531986DEFINE_OP_GET_EFFECTS (SortOp)
19541987DEFINE_OP_GET_EFFECTS (FftOp)
19551988DEFINE_OP_GET_EFFECTS (ScanOp)
0 commit comments