Skip to content

Commit cce490d

Browse files
anupgangwarAnup Gangwar
andauthored
* [tosa] Support for Rsqrt legalization (#480)
Signed-off-by: Anup Gangwar <[email protected]> Co-authored-by: Anup Gangwar <[email protected]>
1 parent 6dabf18 commit cce490d

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@
4343
"BoolTensorReturnFalseModule_basic",
4444
"BoolTensorReturnTrueModule_basic",
4545
"BoolTensorReturnMixedModule_basic",
46+
"ElementwiseRsqrtModule_basic",
4647
}

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
442442
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
443443
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
444444
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
445+
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
445446
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
446447
#undef INSERT_UNARY_PATTERN
447448

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
285285
return %0 : !torch.vtensor<[1],i1>
286286
}
287287

288+
// -----
289+
290+
// CHECK-LABEL: func @torch.aten.rsqrt$basic(
291+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
292+
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
293+
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
294+
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
295+
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
296+
// CHECK: }
297+
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
298+
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
299+
return %0 : !torch.vtensor<[?,?],f32>
300+
}

0 commit comments

Comments
 (0)