@@ -1372,54 +1372,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13721372 return mlir::success ();
13731373}
13741374
1375- LogicalResult tosa::TransposeOp::getConstantPerms (SmallVector<int32_t > &perms) {
1376- // Perms must be constants.
1377- DenseIntElementsAttr permsAttr;
1378- if (!matchPattern (getPerms (), m_Constant (&permsAttr)))
1379- return failure ();
1380-
1381- perms.clear ();
1382- for (auto v : permsAttr.getValues <APInt>())
1383- perms.push_back (v.getSExtValue ());
1384-
1385- return success ();
1386- }
1387-
13881375LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
13891376 MLIRContext *context, ::std::optional<Location> location,
13901377 TransposeOp::Adaptor adaptor,
13911378 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13921379 ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1393- ShapeAdaptor permsShape (adaptor.getPerms ().getType ());
1394-
1395- // We cannot infer anything from a rank-0 "permutation" tensor.
1396- if (permsShape.hasRank () && permsShape.getRank () == 0 )
1397- return failure ();
13981380
13991381 // If input rank and permutation length is unknown, the output rank is
14001382 // unknown.
1401- if (!inputShape.hasRank () || !permsShape.hasRank () ||
1402- permsShape.isDynamicDim (0 )) {
1383+ if (!inputShape.hasRank ()) {
14031384 inferredReturnShapes.push_back (ShapedTypeComponents ());
14041385 return success ();
14051386 }
14061387
1388+ const auto inputRank = inputShape.getRank ();
1389+
14071390 // This would imply the number of permutations does not match the rank of
14081391 // the input which is illegal.
1409- if (permsShape. getDimSize ( 0 ) != inputShape. getRank ( )) {
1392+ if (adaptor. getPerms (). size () != static_cast < size_t >(inputRank )) {
14101393 return failure ();
14111394 }
14121395
14131396 SmallVector<int64_t > outputShape;
14141397 // Rank-0 means no permutations matter.
1415- if (inputShape. getRank () == 0 ) {
1398+ if (inputRank == 0 ) {
14161399 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
14171400 return success ();
14181401 }
14191402
14201403 // Check whether the input dimensions are all the same.
14211404 bool allTheSame = true ;
1422- for (int i = 1 , s = inputShape. getRank () ; i < s; i++) {
1405+ for (int i = 1 , s = inputRank ; i < s; i++) {
14231406 if (inputShape.getDimSize (0 ) != inputShape.getDimSize (i)) {
14241407 allTheSame = false ;
14251408 break ;
@@ -1429,34 +1412,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14291412 // If all of the input dimensions are the same we don't care about the
14301413 // permutation.
14311414 if (allTheSame) {
1432- outputShape.resize (inputShape. getRank () , inputShape.getDimSize (0 ));
1415+ outputShape.resize (inputRank , inputShape.getDimSize (0 ));
14331416 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
14341417 return success ();
14351418 }
14361419
1437- outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
1438- // If the permuations are a constant we can directly determine the output
1439- // shape.
1440- DenseIntElementsAttr attr;
1441- if (matchPattern (adaptor.getPerms (), m_Constant (&attr)) &&
1442- attr.getType ().getRank () == 1 ) {
1443- ShapeAdaptor permShape = attr;
1444- // Constant permutation must be the same length as the input rank.
1445- if (inputShape.getRank () != permShape.getRank ())
1446- return emitOptionalError (location,
1447- " constant permutation must be the same length"
1448- " as the input rank" );
1449-
1450- // Constant permutation values must be within the input rank.
1451- for (int i = 0 , e = inputShape.getRank (); i < e; i++) {
1452- if (inputShape.getRank () <= permShape.getDimSize (i))
1453- return failure ();
1454- }
1420+ outputShape.resize (inputRank, ShapedType::kDynamic );
14551421
1456- outputShape.reserve (inputShape.getRank ());
1457- for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1458- outputShape[i] = inputShape.getDimSize (permShape.getDimSize (i));
1459- }
1422+ // Constant permutation values must be within the input rank.
1423+ if (llvm::any_of (adaptor.getPerms (),
1424+ [inputRank](const auto i) { return i >= inputRank; }))
1425+ return failure ();
1426+
1427+ outputShape.reserve (inputRank);
1428+ for (int i = 0 , s = inputRank; i < s; i++) {
1429+ outputShape[i] = inputShape.getDimSize (adaptor.getPerms ()[i]);
14601430 }
14611431
14621432 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -1465,75 +1435,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14651435
14661436LogicalResult tosa::TransposeOp::verify () {
14671437 TensorType inputType = getInput1 ().getType ();
1468- TensorType permType = getPerms ().getType ();
14691438 TensorType outputType = getOutput ().getType ();
1439+ const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
14701440
1471- if (permType.hasRank () && permType.getRank () != 1 )
1472- return emitOpError ()
1473- << " expected permutation tensor to be rank 1 but got rank "
1474- << permType.getRank ();
1475- if (inputType.hasRank () && permType.hasRank ())
1476- if (!permType.isDynamicDim (0 ) &&
1477- permType.getDimSize (0 ) != inputType.getRank ())
1478- return emitOpError () << " expected permutation tensor dim 0 to have size "
1479- << inputType.getRank ()
1480- << " (input rank) but got size "
1481- << permType.getDimSize (0 );
1441+ if (inputType.hasRank () &&
1442+ constantPerms.size () != static_cast <size_t >(inputType.getRank ()))
1443+ return emitOpError () << " expected perms attribute to have size "
1444+ << inputType.getRank () << " (input rank) but got size "
1445+ << constantPerms.size ();
14821446 if (inputType.hasRank () && outputType.hasRank () &&
14831447 inputType.getRank () != outputType.getRank ())
14841448 return emitOpError ()
14851449 << " expected input tensor rank to equal result tensor rank" ;
1486- if (outputType.hasRank () && permType.hasRank ())
1487- if (!permType.isDynamicDim (0 ) &&
1488- permType.getDimSize (0 ) != outputType.getRank ())
1489- return emitOpError () << " expected permutation tensor dim 0 to have size "
1490- << outputType.getRank ()
1491- << " (output rank) but got size "
1492- << permType.getDimSize (0 );
1493-
1494- SmallVector<int32_t > constantPerms;
1495- if (succeeded (getConstantPerms (constantPerms))) {
1496- // Assert that the permutation tensor has a rank, which means that the
1497- // rank has been verified above.
1498- assert (permType.hasRank () &&
1499- " Unexpectedly found permutation tensor without rank" );
1500- if (!llvm::all_of (constantPerms,
1501- [&constantPerms](int32_t s) {
1502- return s >= 0 &&
1503- static_cast <size_t >(s) < constantPerms.size ();
1504- }) ||
1505- !isPermutationVector (llvm::to_vector (llvm::map_range (
1506- constantPerms, [](int32_t v) -> int64_t { return v; }))))
1507- return emitOpError () << " expected valid permutation tensor" ;
1508-
1509- // Verify that the types of the input and output tensors are properly
1510- // permuted.
1511- if (inputType.hasRank () && outputType.hasRank ()) {
1512- assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1513- inputType.getRank () == outputType.getRank ());
1514-
1515- for (auto i = 0 ; i < outputType.getRank (); i++) {
1516- if (inputType.isDynamicDim (constantPerms[i]) ||
1517- outputType.isDynamicDim (i))
1518- continue ;
1519-
1520- if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1521- return emitOpError ()
1522- << " expected output tensor dim " << i << " to match "
1523- << " input dim " << constantPerms[i] << " with value of "
1524- << inputType.getDimSize (constantPerms[i]);
1525- }
1450+ if (outputType.hasRank () &&
1451+ constantPerms.size () != static_cast <size_t >(outputType.getRank ()))
1452+ return emitOpError () << " expected perms attribute to have size "
1453+ << outputType.getRank ()
1454+ << " (output rank) but got size "
1455+ << constantPerms.size ();
1456+
1457+ if (!llvm::all_of (constantPerms,
1458+ [&constantPerms](int32_t s) {
1459+ return s >= 0 &&
1460+ static_cast <size_t >(s) < constantPerms.size ();
1461+ }) ||
1462+ !isPermutationVector (llvm::to_vector (llvm::map_range (
1463+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
1464+ return emitOpError () << " expected valid permutation indices" ;
1465+
1466+ // Verify that the types of the input and output tensors are properly
1467+ // permuted.
1468+ if (inputType.hasRank () && outputType.hasRank ()) {
1469+ assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1470+ inputType.getRank () == outputType.getRank ());
1471+
1472+ for (auto i = 0 ; i < outputType.getRank (); i++) {
1473+ if (inputType.isDynamicDim (constantPerms[i]) ||
1474+ outputType.isDynamicDim (i))
1475+ continue ;
1476+
1477+ if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1478+ return emitOpError ()
1479+ << " expected output tensor dim " << i << " to match "
1480+ << " input dim " << constantPerms[i] << " with value of "
1481+ << inputType.getDimSize (constantPerms[i]);
15261482 }
15271483 }
1484+
15281485 return success ();
15291486}
15301487
15311488LogicalResult TransposeOp::reifyResultShapes (
15321489 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
15331490
1534- SmallVector<int32_t > transposePerms;
1535- if (getConstantPerms (transposePerms).failed ())
1536- return failure ();
1491+ const llvm::ArrayRef<int32_t > transposePerms = getPerms ();
15371492
15381493 Value input = getInput1 ();
15391494 auto inputType = cast<TensorType>(input.getType ());
0 commit comments