@@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15351535 });
15361536}
15371537
1538+ // ===----------------------------------------------------------------------===//
1539+ // AtenRSubScalarOp
1540+ // ===----------------------------------------------------------------------===//
1541+
1542+ OpFoldResult AtenRsubScalarOp::fold (FoldAdaptor adaptor) {
1543+ auto fpFold = [](llvm::ArrayRef<double > inputs) {
1544+ assert (inputs.size () == 3 );
1545+ return inputs[1 ] - inputs[0 ] * inputs[2 ];
1546+ };
1547+
1548+ auto intFold = [](llvm::ArrayRef<APInt> inputs) {
1549+ assert (inputs.size () == 3 );
1550+ return inputs[1 ] - inputs[0 ] * inputs[2 ];
1551+ };
1552+
1553+ return naryFolderHelper (adaptor.getOperands (), getType (), fpFold, intFold);
1554+ }
1555+
15381556// ===----------------------------------------------------------------------===//
15391557// AtenMulTensorOp
15401558// ===----------------------------------------------------------------------===//
@@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
19791997 });
19801998}
19811999
2000+ // ===----------------------------------------------------------------------===//
2001+ // AtenDivTensorModeOp
2002+ // ===----------------------------------------------------------------------===//
2003+
2004+ OpFoldResult AtenDivTensorModeOp::fold (FoldAdaptor adaptor) {
2005+ auto resultTy = dyn_cast_or_null<ValueTensorType>(getType ());
2006+ if (!resultTy || !resultTy.hasDtype ()) {
2007+ return nullptr ;
2008+ }
2009+ std::function<double (ArrayRef<double >)> fpFold;
2010+ std::function<APInt (ArrayRef<APInt>)> intFold;
2011+
2012+ auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode ());
2013+ auto unsign = false ;
2014+ if (isa<mlir::IntegerType>(resultTy.getDtype ())) {
2015+ unsign = cast<IntegerType>(resultTy.getDtype ()).isUnsigned ();
2016+ }
2017+
2018+ fpFold = [roundMode](llvm::ArrayRef<double > inputs) {
2019+ assert (inputs.size () == 2 );
2020+ if (!roundMode) {
2021+ return (double )inputs[0 ] / inputs[1 ];
2022+ } else if (roundMode.getValue ().str () == " floor" ) {
2023+ return std::floor ((double )inputs[0 ] / inputs[1 ]);
2024+ } else {
2025+ return std::trunc ((double )inputs[0 ] / inputs[1 ]);
2026+ }
2027+ };
2028+
2029+ intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
2030+ assert (inputs.size () == 2 );
2031+ auto lhs = unsign ? inputs[0 ].getZExtValue () : inputs[0 ].getSExtValue ();
2032+ auto rhs = unsign ? inputs[1 ].getZExtValue () : inputs[1 ].getSExtValue ();
2033+ int64_t bits = std::max (inputs[0 ].getBitWidth (), inputs[1 ].getBitWidth ());
2034+ int64_t res;
2035+ if (roundMode.getValue ().str () == " floor" ) {
2036+ res = std::floor (lhs / rhs);
2037+ } else {
2038+ res = std::trunc (lhs / rhs);
2039+ }
2040+ return APInt (bits, res);
2041+ };
2042+
2043+ if (!roundMode) {
2044+ return naryFolderHelper ({adaptor.getSelf (), adaptor.getOther ()}, getType (),
2045+ fpFold, std::nullopt );
2046+ }
2047+
2048+ return naryFolderHelper ({adaptor.getSelf (), adaptor.getOther ()}, getType (),
2049+ fpFold, intFold);
2050+ }
2051+
19822052// ===----------------------------------------------------------------------===//
19832053// AtenDivScalarModeOp
19842054// ===----------------------------------------------------------------------===//
@@ -3612,6 +3682,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
36123682 adaptor.getOperands (), [](int64_t a, int64_t b) { return a % b; });
36133683}
36143684
3685+ // ===----------------------------------------------------------------------===//
3686+ // AtenRemainderScalarOp
3687+ // ===----------------------------------------------------------------------===//
3688+
3689+ OpFoldResult AtenRemainderScalarOp::fold (FoldAdaptor adaptor) {
3690+ auto resultTy = dyn_cast_or_null<ValueTensorType>(getType ());
3691+ if (!resultTy || !resultTy.hasDtype ()) {
3692+ return nullptr ;
3693+ }
3694+
3695+ auto unsign = false ;
3696+ if (isa<mlir::IntegerType>(resultTy.getDtype ())) {
3697+ unsign = cast<IntegerType>(resultTy.getDtype ()).isUnsigned ();
3698+ }
3699+ auto fpFold = [](llvm::ArrayRef<double > inputs) {
3700+ assert (inputs.size () == 2 );
3701+ return std::fmod (inputs[0 ], inputs[1 ]);
3702+ };
3703+
3704+ auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
3705+ assert (inputs.size () == 2 );
3706+ auto ret = unsign ? inputs[0 ].urem (inputs[1 ]) : inputs[0 ].srem (inputs[1 ]);
3707+ return ret;
3708+ };
3709+
3710+ return naryFolderHelper (adaptor.getOperands (), getType (), fpFold, intFold);
3711+ }
3712+
36153713// ===----------------------------------------------------------------------===//
36163714// AtenAddIntOp
36173715// ===----------------------------------------------------------------------===//
@@ -4313,6 +4411,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
43134411 });
43144412}
43154413
4414+ // ===----------------------------------------------------------------------===//
4415+ // AtenIntTensorOp
4416+ // ===----------------------------------------------------------------------===//
4417+
4418+ OpFoldResult AtenIntTensorOp::fold (FoldAdaptor adaptor) {
4419+ auto value = adaptor.getA ();
4420+ auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
4421+ if (!dense || !dense.isSplat ()) {
4422+ return nullptr ;
4423+ }
4424+
4425+ auto splat = dense.getSplatValue <Attribute>();
4426+ if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
4427+ auto type = getType ();
4428+ if (!isa<mlir::IntegerType>(type)) {
4429+ return nullptr ;
4430+ }
4431+
4432+ if (type.isSignlessInteger ()) {
4433+ return getI64IntegerAttr (getContext (), intAttr.getInt ());
4434+ } else if (type.isSignedInteger ()) {
4435+ return getI64IntegerAttr (getContext (), intAttr.getSInt ());
4436+ } else {
4437+ return getI64IntegerAttr (getContext (), intAttr.getUInt ());
4438+ }
4439+ }
4440+
4441+ if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
4442+ return getI64IntegerAttr (
4443+ getContext (),
4444+ static_cast <long >(floatAttr.getValue ().convertToDouble ()));
4445+ }
4446+
4447+ return nullptr ;
4448+ }
4449+
43164450// ===----------------------------------------------------------------------===//
43174451// AtenFloatTensorOp
43184452// ===----------------------------------------------------------------------===//
0 commit comments