@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
22622262}
22632263
22642264LogicalResult tosa::GatherOp::verify () {
2265- return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2266- /* outType = */ getOutput ().getType ());
2265+ if (verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2266+ /* outType = */ getOutput ().getType ())
2267+ .failed ()) {
2268+ return failure ();
2269+ }
2270+
2271+ const ShapeAdaptor valuesShape (getValues ().getType ());
2272+ const ShapeAdaptor indicesShape (getIndices ().getType ());
2273+ const ShapeAdaptor outputShape (getOutput ().getType ());
2274+
2275+ int64_t N = ShapedType::kDynamic ;
2276+ int64_t W = ShapedType::kDynamic ;
2277+ int64_t C = ShapedType::kDynamic ;
2278+
2279+ if (valuesShape.hasRank ()) {
2280+ N = valuesShape.getDimSize (0 );
2281+ C = valuesShape.getDimSize (2 );
2282+ }
2283+ if (indicesShape.hasRank ()) {
2284+ const int64_t indicesN = indicesShape.getDimSize (0 );
2285+ W = indicesShape.getDimSize (1 );
2286+ if (N == ShapedType::kDynamic )
2287+ N = indicesN;
2288+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
2289+ return emitOpError () << " requires indices dimension 0 to have size " << N
2290+ << " , got " << indicesN;
2291+ }
2292+ if (outputShape.hasRank ()) {
2293+ const int64_t outputN = outputShape.getDimSize (0 );
2294+ const int64_t outputW = outputShape.getDimSize (1 );
2295+ const int64_t outputC = outputShape.getDimSize (2 );
2296+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2297+ N != outputN)
2298+ return emitOpError () << " requires output dimension 0 to have size " << N
2299+ << " , got " << outputN;
2300+
2301+ if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2302+ W != outputW)
2303+ return emitOpError () << " requires output dimension 1 to have size " << W
2304+ << " , got " << outputW;
2305+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2306+ C != outputC)
2307+ return emitOpError () << " requires output dimension 2 to have size " << C
2308+ << " , got " << outputC;
2309+ }
2310+ return success ();
22672311}
22682312
22692313LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
0 commit comments