@@ -1616,33 +1616,43 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
16161616 MLIRContext *context, ::std::optional<Location> location,
16171617 TileOp::Adaptor adaptor,
16181618 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1619- DenseIntElementsAttr multiplesAttr;
1620- if (!matchPattern (adaptor.getMultiples (), m_Constant (&multiplesAttr)))
1621- return failure ();
1622-
1623- SmallVector<int64_t > multiples = llvm::to_vector (
1624- llvm::map_range (multiplesAttr.getValues <APInt>(),
1625- [](const APInt &val) { return val.getSExtValue (); }));
1619+ Type inputType = getElementTypeOrSelf (adaptor.getInput1 ().getType ());
1620+ SmallVector<int64_t > multiples;
1621+ if (!tosa::getConstShapeValues (adaptor.getMultiples ().getDefiningOp (),
1622+ multiples)) {
1623+ auto rank =
1624+ cast<tosa::shapeType>(adaptor.getMultiples ().getType ()).getRank ();
1625+ SmallVector<int64_t > fallback (rank, ShapedType::kDynamic );
1626+ inferredReturnShapes.push_back (ShapedTypeComponents (fallback, inputType));
1627+ return success ();
1628+ } else {
1629+ multiples = convertToMlirShape (multiples);
1630+ }
16261631
16271632 ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
16281633 SmallVector<int64_t > outputShape;
16291634 if (!inputShape.hasRank ()) {
16301635 outputShape.resize (multiples.size (), ShapedType::kDynamic );
1631- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1636+ inferredReturnShapes.push_back (
1637+ ShapedTypeComponents (outputShape, inputType));
16321638 return success ();
16331639 } else if (static_cast <size_t >(inputShape.getRank ()) != multiples.size ())
16341640 return failure ();
16351641
16361642 // Any non dynamic dimension can be multiplied to a known size.
16371643 outputShape.reserve (multiples.size ());
16381644 for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1639- int64_t dim = inputShape.getDimSize (i);
1640- if (dim != ShapedType::kDynamic )
1641- dim *= multiples[i];
1642- outputShape.push_back (dim);
1645+ if (multiples[i] == ShapedType::kDynamic ) {
1646+ outputShape.push_back (ShapedType::kDynamic );
1647+ } else {
1648+ int64_t dim = inputShape.getDimSize (i);
1649+ if (dim != ShapedType::kDynamic )
1650+ dim *= multiples[i];
1651+ outputShape.push_back (dim);
1652+ }
16431653 }
16441654
1645- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1655+ inferredReturnShapes.push_back (ShapedTypeComponents (outputShape, inputType ));
16461656 return success ();
16471657}
16481658
0 commit comments