@@ -1374,41 +1374,22 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13741374 return mlir::success ();
13751375}
13761376
1377- LogicalResult tosa::TransposeOp::getConstantPerms (SmallVector<int32_t > &perms) {
1378- // Perms must be constants.
1379- DenseIntElementsAttr permsAttr;
1380- if (!matchPattern (getPerms (), m_Constant (&permsAttr)))
1381- return failure ();
1382-
1383- perms.clear ();
1384- for (auto v : permsAttr.getValues <APInt>())
1385- perms.push_back (v.getSExtValue ());
1386-
1387- return success ();
1388- }
1389-
13901377LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
13911378 MLIRContext *context, ::std::optional<Location> location,
13921379 TransposeOp::Adaptor adaptor,
13931380 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13941381 ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1395- ShapeAdaptor permsShape (adaptor.getPerms ().getType ());
1396-
1397- // We cannot infer anything from a rank-0 "permutation" tensor.
1398- if (permsShape.hasRank () && permsShape.getRank () == 0 )
1399- return failure ();
14001382
14011383 // If input rank and permutation length is unknown, the output rank is
14021384 // unknown.
1403- if (!inputShape.hasRank () || !permsShape.hasRank () ||
1404- permsShape.isDynamicDim (0 )) {
1385+ if (!inputShape.hasRank ()) {
14051386 inferredReturnShapes.push_back (ShapedTypeComponents ());
14061387 return success ();
14071388 }
14081389
14091390 // This would imply the number of permutations does not match the rank of
14101391 // the input which is illegal.
1411- if (permsShape. getDimSize ( 0 ) != inputShape.getRank ()) {
1392+ if (adaptor. getPerms (). size () != static_cast < size_t >( inputShape.getRank () )) {
14121393 return failure ();
14131394 }
14141395
@@ -1437,28 +1418,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14371418 }
14381419
14391420 outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
1440- // If the permuations are a constant we can directly determine the output
1441- // shape.
1442- DenseIntElementsAttr attr;
1443- if (matchPattern (adaptor.getPerms (), m_Constant (&attr)) &&
1444- attr.getType ().getRank () == 1 ) {
1445- ShapeAdaptor permShape = attr;
1446- // Constant permutation must be the same length as the input rank.
1447- if (inputShape.getRank () != permShape.getRank ())
1448- return emitOptionalError (location,
1449- " constant permutation must be the same length"
1450- " as the input rank" );
1451-
1452- // Constant permutation values must be within the input rank.
1453- for (int i = 0 , e = inputShape.getRank (); i < e; i++) {
1454- if (inputShape.getRank () <= permShape.getDimSize (i))
1455- return failure ();
1456- }
14571421
1458- outputShape.reserve (inputShape.getRank ());
1459- for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1460- outputShape[i] = inputShape.getDimSize (permShape.getDimSize (i));
1461- }
1422+ // Constant permutation values must be within the input rank.
1423+ for (auto i : adaptor.getPerms ()) {
1424+ if (inputShape.getRank () <= i)
1425+ return failure ();
1426+ }
1427+
1428+ outputShape.reserve (inputShape.getRank ());
1429+ for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1430+ outputShape[i] = inputShape.getDimSize (adaptor.getPerms ()[i]);
14621431 }
14631432
14641433 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -1467,75 +1436,61 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14671436
14681437LogicalResult tosa::TransposeOp::verify () {
14691438 TensorType inputType = getInput1 ().getType ();
1470- TensorType permType = getPerms ().getType ();
14711439 TensorType outputType = getOutput ().getType ();
1440+ const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
14721441
1473- if (permType.hasRank () && permType.getRank () != 1 )
1474- return emitOpError ()
1475- << " expected permutation tensor to be rank 1 but got rank "
1476- << permType.getRank ();
1477- if (inputType.hasRank () && permType.hasRank ())
1478- if (!permType.isDynamicDim (0 ) &&
1479- permType.getDimSize (0 ) != inputType.getRank ())
1480- return emitOpError () << " expected permutation tensor dim 0 to have size "
1442+ if (inputType.hasRank ())
1443+ if (constantPerms.size () != static_cast <size_t >(inputType.getRank ()))
1444+ return emitOpError () << " expected perms attribute to have size "
14811445 << inputType.getRank ()
14821446 << " (input rank) but got size "
1483- << permType. getDimSize ( 0 );
1447+ << constantPerms. size ( );
14841448 if (inputType.hasRank () && outputType.hasRank () &&
14851449 inputType.getRank () != outputType.getRank ())
14861450 return emitOpError ()
14871451 << " expected input tensor rank to equal result tensor rank" ;
1488- if (outputType.hasRank () && permType.hasRank ())
1489- if (!permType.isDynamicDim (0 ) &&
1490- permType.getDimSize (0 ) != outputType.getRank ())
1491- return emitOpError () << " expected permutation tensor dim 0 to have size "
1452+ if (outputType.hasRank ())
1453+ if (constantPerms.size () != static_cast <size_t >(outputType.getRank ()))
1454+ return emitOpError () << " expected perms attribute to have size "
14921455 << outputType.getRank ()
14931456 << " (output rank) but got size "
1494- << permType.getDimSize (0 );
1495-
1496- SmallVector<int32_t > constantPerms;
1497- if (succeeded (getConstantPerms (constantPerms))) {
1498- // Assert that the permutation tensor has a rank, which means that the
1499- // rank has been verified above.
1500- assert (permType.hasRank () &&
1501- " Unexpectedly found permutation tensor without rank" );
1502- if (!llvm::all_of (constantPerms,
1503- [&constantPerms](int32_t s) {
1504- return s >= 0 &&
1505- static_cast <size_t >(s) < constantPerms.size ();
1506- }) ||
1507- !isPermutationVector (llvm::to_vector (llvm::map_range (
1508- constantPerms, [](int32_t v) -> int64_t { return v; }))))
1509- return emitOpError () << " expected valid permutation tensor" ;
1510-
1511- // Verify that the types of the input and output tensors are properly
1512- // permuted.
1513- if (inputType.hasRank () && outputType.hasRank ()) {
1514- assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1515- inputType.getRank () == outputType.getRank ());
1516-
1517- for (auto i = 0 ; i < outputType.getRank (); i++) {
1518- if (inputType.isDynamicDim (constantPerms[i]) ||
1519- outputType.isDynamicDim (i))
1520- continue ;
1521-
1522- if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1523- return emitOpError ()
1524- << " expected output tensor dim " << i << " to match "
1525- << " input dim " << constantPerms[i] << " with value of "
1526- << inputType.getDimSize (constantPerms[i]);
1527- }
1457+ << constantPerms.size ();
1458+
1459+ if (!llvm::all_of (constantPerms,
1460+ [&constantPerms](int32_t s) {
1461+ return s >= 0 &&
1462+ static_cast <size_t >(s) < constantPerms.size ();
1463+ }) ||
1464+ !isPermutationVector (llvm::to_vector (llvm::map_range (
1465+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
1466+ return emitOpError () << " expected valid permutation indices" ;
1467+
1468+ // Verify that the types of the input and output tensors are properly
1469+ // permuted.
1470+ if (inputType.hasRank () && outputType.hasRank ()) {
1471+ assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1472+ inputType.getRank () == outputType.getRank ());
1473+
1474+ for (auto i = 0 ; i < outputType.getRank (); i++) {
1475+ if (inputType.isDynamicDim (constantPerms[i]) ||
1476+ outputType.isDynamicDim (i))
1477+ continue ;
1478+
1479+ if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1480+ return emitOpError ()
1481+ << " expected output tensor dim " << i << " to match "
1482+ << " input dim " << constantPerms[i] << " with value of "
1483+ << inputType.getDimSize (constantPerms[i]);
15281484 }
15291485 }
1486+
15301487 return success ();
15311488}
15321489
15331490LogicalResult TransposeOp::reifyResultShapes (
15341491 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
15351492
1536- SmallVector<int32_t > transposePerms;
1537- if (getConstantPerms (transposePerms).failed ())
1538- return failure ();
1493+ const llvm::ArrayRef<int32_t > transposePerms = getPerms ();
15391494
15401495 Value input = getInput1 ();
15411496 auto inputType = cast<TensorType>(input.getType ());
0 commit comments