Commit d69d29b

Anup Gangwar authored and cathyzhyi committed
* [tosa] Support for AtenPowTensorScalarOp with constant Scalar as input
Signed-off-by: Anup Gangwar <[email protected]>
1 parent 077e55d commit d69d29b

5 files changed: +67 -0 lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -50,4 +50,5 @@
     "SqueezeDimModule_identity",
     "SqueezeDimModule_unitDim",
     "ReturnTwoTensorF32I64_basic",
+    "ElementwisePowModule_basic",
 }
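
The added "ElementwisePowModule_basic" entry names the e2e test that exercises this lowering. As a rough illustration only (the class name and body below are assumptions, not the actual test source, and the real test uses the project's e2e harness rather than a bare nn.Module), a module like the following produces aten.pow.Tensor_Scalar, the op this commit lowers:

import torch

class ElementwisePowSketch(torch.nn.Module):
    # A Python float exponent keeps the op in its aten.pow.Tensor_Scalar
    # form, which is what this commit converts to tosa.pow.
    def forward(self, x):
        return torch.pow(x, 2.0)

# Eager-mode check of the semantics being lowered.
print(ElementwisePowSketch()(torch.rand(3, 4)))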

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 44 additions & 0 deletions
@@ -635,6 +635,49 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
+// FIXME(AG): This will eventually go into a Tosa*Utils file
+// Convert an fp32 scalar into tosa fp32 tensor.
+static LogicalResult
+tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value torchScalarValue, Value &tosaTensor) {
+  double scalarValue;
+
+  if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
+    return failure();
+
+  // Construct a tosa.const
+  tosaTensor =
+      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
+    AtenPowTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Value self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in TOSA Pow");
+
+  if (!selfTy.getElementType().isa<mlir::FloatType>())
+    return op.emitError("Only floating-point datatype legalization supported");
+
+  Value expTensor;
+  Value expScalar = op.exponent();
+  if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
+                                         expTensor)))
+    return op.emitError("Currently only scalar constants are supported for "
+                        "conversion in TOSA Pow operation");
+
+  rewriter.replaceOpWithNewOp<tosa::PowOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, expTensor);
+
+  return success();
+}
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -740,6 +783,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenMulTensorOp);
     INSERT_ATENOP_PATTERN(AtenDivTensorOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
+    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
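
The pattern above materializes the constant float exponent as a rank-0 tosa.const (tensor<f32>) and emits tosa.pow, relying on the rank-0 operand broadcasting against the ranked input, as the new FileCheck test at the bottom of this commit shows. A small NumPy sketch of that broadcast behavior (purely illustrative; NumPy plays no role in the lowering):

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)  # ranked input
exponent = np.array(3.1234, dtype=np.float32)             # rank-0 "constant" exponent

# The rank-0 exponent broadcasts elementwise against the rank-2 input,
# mirroring tosa.pow(%input_tensor, %const_exponent).
print(np.power(x, exponent))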

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
@@ -101,6 +102,9 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
     pm.addNestedPass<FuncOp>(createCSEPass());
   }
 
+  // Add the ToStandard pass for lowering some ops
+  pm.addNestedPass<FuncOp>(createTosaToStandard());
+
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
   pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());

lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ class VerifyTosaBackendContractPass
     target.addLegalDialect<tosa::TosaDialect>();
     target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
     target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
+    target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
 
     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 17 additions & 0 deletions
@@ -330,3 +330,20 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
   %0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %fp0 = torch.constant.float 3.123400e+00
+  %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
