@@ -232,20 +232,6 @@ static llvm::Optional<py::object> getPyLiteral(mlir::Attribute attr) {
232
232
return {};
233
233
}
234
234
235
- static llvm::Optional<py::object> makePyLiteral (mlir::Value val) {
236
- assert (val);
237
- if (auto literal = val.getType ().dyn_cast <plier::LiteralType>())
238
- return getPyLiteral (literal.getValue ());
239
-
240
- if (auto cast = val.getDefiningOp <plier::SignCastOp>())
241
- val = cast.value ();
242
-
243
- if (auto attr = plier::getConstVal<mlir::Attribute>(val))
244
- return getPyLiteral (attr);
245
-
246
- return {};
247
- }
248
-
249
235
static mlir::Value doCast (mlir::OpBuilder &builder, mlir::Location loc,
250
236
mlir::Value val, mlir::Type type) {
251
237
if (val.getType () != type)
@@ -284,7 +270,7 @@ struct PyLinalgResolver::Context {
284
270
if (auto typevar = type.dyn_cast <plier::TypeVar>())
285
271
return createType (typevar.getType ());
286
272
287
- if (auto literal = makePyLiteral (value))
273
+ if (auto literal = makePyLiteral (context, value))
288
274
return *literal;
289
275
290
276
auto ret = var (context, wrapMlir (value));
@@ -400,6 +386,32 @@ struct PyLinalgResolver::Context {
400
386
401
387
return doCast (builder, loc, unwrapVal (loc, builder, obj), resultType);
402
388
}
389
+
390
+ private:
391
+ llvm::Optional<py::object> makePyLiteral (py::capsule context,
392
+ mlir::Value val) {
393
+ assert (val);
394
+ if (auto literal = val.getType ().dyn_cast <plier::LiteralType>())
395
+ return getPyLiteral (literal.getValue ());
396
+
397
+ if (auto buildTuple = val.getDefiningOp <plier::BuildTupleOp>()) {
398
+ auto args = buildTuple.args ();
399
+ auto count = static_cast <unsigned >(args.size ());
400
+ py::tuple ret (count);
401
+ for (auto i : llvm::seq (0u , count))
402
+ ret[i] = createVar (context, args[i]);
403
+
404
+ return ret;
405
+ }
406
+
407
+ if (auto cast = val.getDefiningOp <plier::SignCastOp>())
408
+ val = cast.value ();
409
+
410
+ if (auto attr = plier::getConstVal<mlir::Attribute>(val))
411
+ return getPyLiteral (attr);
412
+
413
+ return {};
414
+ }
403
415
};
404
416
405
417
namespace {
@@ -1653,13 +1665,19 @@ static py::object getitemImpl(py::capsule context, py::capsule ssaVal,
1653
1665
template <typename Op>
1654
1666
static mlir::Value binopFunc (mlir::Location loc, mlir::OpBuilder &builder,
1655
1667
mlir::Value lhs, mlir::Value rhs) {
1656
- return builder.create <Op>(loc, lhs, rhs);
1668
+ auto lhsVar = doSignCast (builder, loc, lhs);
1669
+ auto rhsVar = doSignCast (builder, loc, rhs);
1670
+ auto res = builder.create <Op>(loc, lhsVar, rhsVar);
1671
+ return doSignCast (builder, loc, res, lhs.getType ());
1657
1672
}
1658
1673
1659
1674
template <typename Op>
1660
1675
static mlir::Value rbinopFunc (mlir::Location loc, mlir::OpBuilder &builder,
1661
1676
mlir::Value lhs, mlir::Value rhs) {
1662
- return builder.create <Op>(loc, rhs, lhs);
1677
+ auto lhsVar = doSignCast (builder, loc, lhs);
1678
+ auto rhsVar = doSignCast (builder, loc, rhs);
1679
+ auto res = builder.create <Op>(loc, rhsVar, lhsVar);
1680
+ return doSignCast (builder, loc, res, lhs.getType ());
1663
1681
}
1664
1682
1665
1683
static mlir::Value binopFuncIdiv (mlir::Location loc, mlir::OpBuilder &builder,
@@ -1669,6 +1687,14 @@ static mlir::Value binopFuncIdiv(mlir::Location loc, mlir::OpBuilder &builder,
1669
1687
return builder.create <mlir::arith::DivFOp>(loc, lhsVar, rhsVar);
1670
1688
}
1671
1689
1690
+ static mlir::Value binopFFloorDiv (mlir::Location loc, mlir::OpBuilder &builder,
1691
+ mlir::Value lhs, mlir::Value rhs) {
1692
+ auto lhsVar = doCast (builder, loc, lhs, builder.getF64Type ());
1693
+ auto rhsVar = doCast (builder, loc, rhs, builder.getF64Type ());
1694
+ auto res = builder.create <mlir::arith::DivFOp>(loc, lhsVar, rhsVar);
1695
+ return builder.create <mlir::math::FloorOp>(loc, res);
1696
+ }
1697
+
1672
1698
template <mlir::arith::CmpIPredicate Pred>
1673
1699
static mlir::Value binopCmpI (mlir::Location loc, mlir::OpBuilder &builder,
1674
1700
mlir::Value lhs, mlir::Value rhs) {
@@ -1713,6 +1739,7 @@ static py::object binopImpl(py::capsule context, py::capsule ssaVal,
1713
1739
&rbinopFunc<mlir::arith::SubFOp>},
1714
1740
{" *" , &binopFunc<mlir::arith::MulIOp>, &binopFunc<mlir::arith::MulFOp>},
1715
1741
{" /" , &binopFuncIdiv, &binopFunc<mlir::arith::DivFOp>},
1742
+ {" //" , &binopFunc<mlir::arith::DivSIOp>, &binopFFloorDiv},
1716
1743
1717
1744
{" lt" , &binopCmpI<mlir::arith::CmpIPredicate::slt>,
1718
1745
&binopCmpF<mlir::arith::CmpFPredicate::OLT>},
0 commit comments