@@ -1374,54 +1374,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13741374 return mlir::success ();
13751375}
13761376
1377- LogicalResult tosa::TransposeOp::getConstantPerms (SmallVector<int32_t > &perms) {
1378- // Perms must be constants.
1379- DenseIntElementsAttr permsAttr;
1380- if (!matchPattern (getPerms (), m_Constant (&permsAttr)))
1381- return failure ();
1382-
1383- perms.clear ();
1384- for (auto v : permsAttr.getValues <APInt>())
1385- perms.push_back (v.getSExtValue ());
1386-
1387- return success ();
1388- }
1389-
13901377LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
13911378 MLIRContext *context, ::std::optional<Location> location,
13921379 TransposeOp::Adaptor adaptor,
13931380 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13941381 ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1395- ShapeAdaptor permsShape (adaptor.getPerms ().getType ());
1396-
1397- // We cannot infer anything from a rank-0 "permutation" tensor.
1398- if (permsShape.hasRank () && permsShape.getRank () == 0 )
1399- return failure ();
14001382
14011383 // If input rank and permutation length is unknown, the output rank is
14021384 // unknown.
1403- if (!inputShape.hasRank () || !permsShape.hasRank () ||
1404- permsShape.isDynamicDim (0 )) {
1385+ if (!inputShape.hasRank ()) {
14051386 inferredReturnShapes.push_back (ShapedTypeComponents ());
14061387 return success ();
14071388 }
14081389
1390+ const auto inputRank = inputShape.getRank ();
1391+
14091392 // This would imply the number of permutations does not match the rank of
14101393 // the input which is illegal.
1411- if (permsShape. getDimSize ( 0 ) != inputShape. getRank ( )) {
1394+ if (adaptor. getPerms (). size () != static_cast < size_t >(inputRank )) {
14121395 return failure ();
14131396 }
14141397
14151398 SmallVector<int64_t > outputShape;
14161399 // Rank-0 means no permutations matter.
1417- if (inputShape. getRank () == 0 ) {
1400+ if (inputRank == 0 ) {
14181401 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
14191402 return success ();
14201403 }
14211404
14221405 // Check whether the input dimensions are all the same.
14231406 bool allTheSame = true ;
1424- for (int i = 1 , s = inputShape. getRank () ; i < s; i++) {
1407+ for (int i = 1 , s = inputRank ; i < s; i++) {
14251408 if (inputShape.getDimSize (0 ) != inputShape.getDimSize (i)) {
14261409 allTheSame = false ;
14271410 break ;
@@ -1431,34 +1414,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14311414 // If all of the input dimensions are the same we don't care about the
14321415 // permutation.
14331416 if (allTheSame) {
1434- outputShape.resize (inputShape. getRank () , inputShape.getDimSize (0 ));
1417+ outputShape.resize (inputRank , inputShape.getDimSize (0 ));
14351418 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
14361419 return success ();
14371420 }
14381421
1439- outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
1440- // If the permuations are a constant we can directly determine the output
1441- // shape.
1442- DenseIntElementsAttr attr;
1443- if (matchPattern (adaptor.getPerms (), m_Constant (&attr)) &&
1444- attr.getType ().getRank () == 1 ) {
1445- ShapeAdaptor permShape = attr;
1446- // Constant permutation must be the same length as the input rank.
1447- if (inputShape.getRank () != permShape.getRank ())
1448- return emitOptionalError (location,
1449- " constant permutation must be the same length"
1450- " as the input rank" );
1451-
1452- // Constant permutation values must be within the input rank.
1453- for (int i = 0 , e = inputShape.getRank (); i < e; i++) {
1454- if (inputShape.getRank () <= permShape.getDimSize (i))
1455- return failure ();
1456- }
1422+ outputShape.resize (inputRank, ShapedType::kDynamic );
14571423
1458- outputShape.reserve (inputShape.getRank ());
1459- for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1460- outputShape[i] = inputShape.getDimSize (permShape.getDimSize (i));
1461- }
1424+ // Constant permutation values must be within the input rank.
1425+ if (llvm::any_of (adaptor.getPerms (),
1426+ [inputRank](const auto i) { return i >= inputRank; }))
1427+ return failure ();
1428+
1429+ outputShape.reserve (inputRank);
1430+ for (int i = 0 , s = inputRank; i < s; i++) {
1431+ outputShape[i] = inputShape.getDimSize (adaptor.getPerms ()[i]);
14621432 }
14631433
14641434 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -1467,75 +1437,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14671437
14681438LogicalResult tosa::TransposeOp::verify () {
14691439 TensorType inputType = getInput1 ().getType ();
1470- TensorType permType = getPerms ().getType ();
14711440 TensorType outputType = getOutput ().getType ();
1441+ const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
14721442
1473- if (permType.hasRank () && permType.getRank () != 1 )
1474- return emitOpError ()
1475- << " expected permutation tensor to be rank 1 but got rank "
1476- << permType.getRank ();
1477- if (inputType.hasRank () && permType.hasRank ())
1478- if (!permType.isDynamicDim (0 ) &&
1479- permType.getDimSize (0 ) != inputType.getRank ())
1480- return emitOpError () << " expected permutation tensor dim 0 to have size "
1481- << inputType.getRank ()
1482- << " (input rank) but got size "
1483- << permType.getDimSize (0 );
1443+ if (inputType.hasRank () &&
1444+ constantPerms.size () != static_cast <size_t >(inputType.getRank ()))
1445+ return emitOpError () << " expected perms attribute to have size "
1446+ << inputType.getRank () << " (input rank) but got size "
1447+ << constantPerms.size ();
14841448 if (inputType.hasRank () && outputType.hasRank () &&
14851449 inputType.getRank () != outputType.getRank ())
14861450 return emitOpError ()
14871451 << " expected input tensor rank to equal result tensor rank" ;
1488- if (outputType.hasRank () && permType.hasRank ())
1489- if (!permType.isDynamicDim (0 ) &&
1490- permType.getDimSize (0 ) != outputType.getRank ())
1491- return emitOpError () << " expected permutation tensor dim 0 to have size "
1492- << outputType.getRank ()
1493- << " (output rank) but got size "
1494- << permType.getDimSize (0 );
1495-
1496- SmallVector<int32_t > constantPerms;
1497- if (succeeded (getConstantPerms (constantPerms))) {
1498- // Assert that the permutation tensor has a rank, which means that the
1499- // rank has been verified above.
1500- assert (permType.hasRank () &&
1501- " Unexpectedly found permutation tensor without rank" );
1502- if (!llvm::all_of (constantPerms,
1503- [&constantPerms](int32_t s) {
1504- return s >= 0 &&
1505- static_cast <size_t >(s) < constantPerms.size ();
1506- }) ||
1507- !isPermutationVector (llvm::to_vector (llvm::map_range (
1508- constantPerms, [](int32_t v) -> int64_t { return v; }))))
1509- return emitOpError () << " expected valid permutation tensor" ;
1510-
1511- // Verify that the types of the input and output tensors are properly
1512- // permuted.
1513- if (inputType.hasRank () && outputType.hasRank ()) {
1514- assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1515- inputType.getRank () == outputType.getRank ());
1516-
1517- for (auto i = 0 ; i < outputType.getRank (); i++) {
1518- if (inputType.isDynamicDim (constantPerms[i]) ||
1519- outputType.isDynamicDim (i))
1520- continue ;
1521-
1522- if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1523- return emitOpError ()
1524- << " expected output tensor dim " << i << " to match "
1525- << " input dim " << constantPerms[i] << " with value of "
1526- << inputType.getDimSize (constantPerms[i]);
1527- }
1452+ if (outputType.hasRank () &&
1453+ constantPerms.size () != static_cast <size_t >(outputType.getRank ()))
1454+ return emitOpError () << " expected perms attribute to have size "
1455+ << outputType.getRank ()
1456+ << " (output rank) but got size "
1457+ << constantPerms.size ();
1458+
1459+ if (!llvm::all_of (constantPerms,
1460+ [&constantPerms](int32_t s) {
1461+ return s >= 0 &&
1462+ static_cast <size_t >(s) < constantPerms.size ();
1463+ }) ||
1464+ !isPermutationVector (llvm::to_vector (llvm::map_range (
1465+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
1466+ return emitOpError () << " expected valid permutation indices" ;
1467+
1468+ // Verify that the types of the input and output tensors are properly
1469+ // permuted.
1470+ if (inputType.hasRank () && outputType.hasRank ()) {
1471+ assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1472+ inputType.getRank () == outputType.getRank ());
1473+
1474+ for (auto i = 0 ; i < outputType.getRank (); i++) {
1475+ if (inputType.isDynamicDim (constantPerms[i]) ||
1476+ outputType.isDynamicDim (i))
1477+ continue ;
1478+
1479+ if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1480+ return emitOpError ()
1481+ << " expected output tensor dim " << i << " to match "
1482+ << " input dim " << constantPerms[i] << " with value of "
1483+ << inputType.getDimSize (constantPerms[i]);
15281484 }
15291485 }
1486+
15301487 return success ();
15311488}
15321489
15331490LogicalResult TransposeOp::reifyResultShapes (
15341491 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
15351492
1536- SmallVector<int32_t > transposePerms;
1537- if (getConstantPerms (transposePerms).failed ())
1538- return failure ();
1493+ const llvm::ArrayRef<int32_t > transposePerms = getPerms ();
15391494
15401495 Value input = getInput1 ();
15411496 auto inputType = cast<TensorType>(input.getType ());
0 commit comments