@@ -1616,6 +1616,67 @@ struct LowerTupleCasts : public mlir::OpConversionPattern<plier::CastOp> {
1616
1616
}
1617
1617
};
1618
1618
1619
+ static mlir::Value convertTensorElements (mlir::OpBuilder &builder,
1620
+ mlir::Location loc, mlir::Value src,
1621
+ mlir::ShapedType dstType,
1622
+ mlir::Type origSrcElemType,
1623
+ mlir::Type origDstElemType) {
1624
+ assert (src.getType ().isa <mlir::ShapedType>());
1625
+ auto srcType = src.getType ().cast <mlir::ShapedType>();
1626
+ assert (srcType.getRank () == dstType.getRank ());
1627
+ if (srcType.getElementType () == dstType.getElementType ())
1628
+ return src;
1629
+
1630
+ if (srcType.isa <mlir::MemRefType>())
1631
+ src = builder.create <mlir::bufferization::ToTensorOp>(loc, src);
1632
+
1633
+ auto rank = static_cast <unsigned >(srcType.getRank ());
1634
+ llvm::SmallVector<mlir::Value> shape (rank);
1635
+ for (auto i : llvm::seq (0u , rank))
1636
+ shape[i] = builder.create <mlir::tensor::DimOp>(loc, src, i);
1637
+
1638
+ mlir::Value init = builder.create <mlir::linalg::InitTensorOp>(
1639
+ loc, shape, dstType.getElementType ());
1640
+
1641
+ auto affineMap =
1642
+ mlir::AffineMap::getMultiDimIdentityMap (rank, builder.getContext ());
1643
+ const mlir::AffineMap maps[] = {
1644
+ affineMap,
1645
+ affineMap,
1646
+ };
1647
+
1648
+ llvm::SmallVector<mlir::StringRef> iterators (rank, " parallel" );
1649
+
1650
+ auto bodyBuilder = [&](mlir::OpBuilder &b, mlir::Location l,
1651
+ mlir::ValueRange args) {
1652
+ assert (args.size () == 2 );
1653
+ auto doSignCast = [&](mlir::Value src, mlir::Type dstType) -> mlir::Value {
1654
+ if (src.getType () != dstType)
1655
+ return b.create <plier::SignCastOp>(l, dstType, src);
1656
+
1657
+ return src;
1658
+ };
1659
+ auto arg = doSignCast (args.front (), origSrcElemType);
1660
+ arg = b.create <plier::CastOp>(l, origDstElemType, arg);
1661
+ arg = doSignCast (arg, dstType.getElementType ());
1662
+ b.create <mlir::linalg::YieldOp>(l, arg);
1663
+ };
1664
+ mlir::Value res =
1665
+ builder
1666
+ .create <mlir::linalg::GenericOp>(loc, init.getType (), src, init, maps,
1667
+ iterators, bodyBuilder)
1668
+ .getResult (0 );
1669
+
1670
+ if (dstType.isa <mlir::MemRefType>()) {
1671
+ auto memrefType =
1672
+ mlir::MemRefType::get (dstType.getShape (), dstType.getElementType ());
1673
+ res = builder.create <mlir::bufferization::ToMemrefOp>(loc, memrefType, res);
1674
+ }
1675
+
1676
+ rerunScfPipeline (res.getDefiningOp ());
1677
+ return res;
1678
+ }
1679
+
1619
1680
struct LowerTensorCasts : public mlir ::OpConversionPattern<plier::CastOp> {
1620
1681
using OpConversionPattern::OpConversionPattern;
1621
1682
@@ -1653,6 +1714,9 @@ struct LowerTensorCasts : public mlir::OpConversionPattern<plier::CastOp> {
1653
1714
value =
1654
1715
rewriter.createOrFold <plier::SignCastOp>(loc, signlessSrcType, value);
1655
1716
1717
+ value = convertTensorElements (rewriter, loc, value, signlessDstType,
1718
+ srcElem, dstElem);
1719
+
1656
1720
bool isSrcMemref = srcType.isa <mlir::MemRefType>();
1657
1721
bool isDstMemref = dstType.isa <mlir::MemRefType>();
1658
1722
0 commit comments