Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5236,6 +5236,7 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
Expand All @@ -5260,6 +5261,7 @@ def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [
Expand Down
9 changes: 9 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value exp = operands[1];
Type expType = exp.getType();
if (!expType.isIntOrFloat()) {
pow.emitError("unimplemented: exp type neither float nor int");
return nullptr;
}
if (isa<mlir::IntegerType>(expType)) {
return b.create<math::FPowIOp>(loc, payloadArgs[0], exp);
}
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
Expand Down
60 changes: 60 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -2421,6 +2422,65 @@ OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenPowTensorScalarOp
//===----------------------------------------------------------------------===//

void AtenPowTensorScalarOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
// If the exponent is a float representation of an int,
// convert the exponent to an int
patterns.add(+[](AtenPowTensorScalarOp op, PatternRewriter &rewriter) {
auto exp = getAsOpFoldResult(op.getExponent());
auto baseAttr = dyn_cast<mlir::Attribute>(exp);
auto floatAttr = dyn_cast_or_null<mlir::FloatAttr>(baseAttr);
if (!floatAttr)
return failure();
double expValue = floatAttr.getValueAsDouble();
auto truncValue = static_cast<int64_t>(expValue);
if (expValue != static_cast<double>(truncValue))
return failure();
Value IRScalar =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), truncValue);
op->setOperand(1, IRScalar);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenPowTensorTensorOp
//===----------------------------------------------------------------------===//

void AtenPowTensorTensorOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
// If the exponent is a single element constant, convert to
// AtenPowTensorScalar.
patterns.add(+[](AtenPowTensorTensorOp op, PatternRewriter &rewriter) {
OpFoldResult exp = getAsOpFoldResult(op.getExponent());
auto expAttr = dyn_cast<Attribute>(exp);
auto attr = dyn_cast_or_null<DenseElementsAttr>(expAttr);
if (!attr || attr.getNumElements() != 1)
return failure();
auto elem = *attr.value_begin<Attribute>();
auto intAttr = dyn_cast<mlir::IntegerAttr>(elem);
auto floatAttr = dyn_cast<mlir::FloatAttr>(elem);
if (!intAttr && !floatAttr)
return failure();
Value IRScalar;
if (intAttr)
IRScalar = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), getIntAttrAsSigned(intAttr));
if (floatAttr) {
double expValue = floatAttr.getValueAsDouble();
IRScalar = rewriter.create<Torch::ConstantFloatOp>(op.getLoc(),
APFloat(expValue));
}
rewriter.replaceOpWithNewOp<AtenPowTensorScalarOp>(op, op.getType(),
op.getSelf(), IRScalar);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,12 @@ def emit_with_mutating_variants(key, **kwargs):
has_canonicalizer=True,
)
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True
)
emit(
"aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True
)
emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)")
emit("aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
Expand Down
20 changes: 20 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,26 @@ func.func @torch.aten.remainder.int() -> !torch.int {
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$canonicalize
// CHECK: %[[SCALAR_EXP:.*]] = torch.constant.float 3.5
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[SCALAR_EXP]]
// CHECK: return %[[POW]]
func.func @torch.aten.pow.Tensor_Tensor$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%exponent = torch.vtensor.literal(dense<3.500000e+00> : tensor<f32>) : !torch.vtensor<[],f32>
%pow = torch.aten.pow.Tensor_Tensor %arg0, %exponent : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
return %pow : !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$canonicalize
// CHECK: %[[INT_EXP:.*]] = torch.constant.int 3
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[INT_EXP]]
// CHECK: return %[[POW]]
func.func @torch.aten.pow.Tensor_Scalar$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%float_exponent = torch.constant.float 3.0
%pow = torch.aten.pow.Tensor_Scalar %arg0, %float_exponent : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
return %pow : !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.pow.int_float() -> !torch.float {
// CHECK: %[[FLOAT_8:.*]] = torch.constant.float 8.000000e+00
// CHECK: return %[[FLOAT_8]] : !torch.float
Expand Down
Loading