Skip to content

Commit c1a9417

Browse files
[EE/BE] add float variant to unary_ufunc_realhbbf16_to_floathbf16
Differential Revision: [D77759060](https://our.internmc.facebook.com/intern/diff/D77759060/) ghstack-source-id: 294877446 Pull Request resolved: #12277
1 parent ba19c75 commit c1a9417

30 files changed

+60
-101
lines changed

backends/cadence/fusion_g3/operators/op_exp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
6060
return out;
6161
} else {
6262
return torch::executor::native::internal::
63-
unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out);
63+
unary_ufunc_realhbbf16_to_floathbf16(std::exp, std::exp, ctx, in, out);
6464
}
6565
}
6666

backends/cadence/fusion_g3/operators/op_rsqrt.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ namespace native {
2727

2828
namespace {
2929

30-
double rsqrt(double x) {
30+
template <typename T>
31+
T rsqrt(T x) {
3132
return 1.0 / std::sqrt(x);
3233
}
3334

@@ -61,11 +62,11 @@ Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
6162
return out;
6263
} else {
6364
return torch::executor::native::internal::
64-
unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out);
65+
unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out);
6566
}
6667
}
6768

6869
} // namespace native
6970
} // namespace G3
7071
} // namespace impl
71-
} // namespace cadence
72+
} // namespace cadence

backends/cadence/fusion_g3/operators/op_sqrt.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
5555
return out;
5656
} else {
5757
return torch::executor::native::internal::
58-
unary_ufunc_realhbbf16_to_floathbf16(std::sqrt, ctx, in, out);
58+
unary_ufunc_realhbbf16_to_floathbf16(
59+
std::sqrt, std::sqrt, ctx, in, out);
5960
}
6061
}
6162

backends/cadence/fusion_g3/operators/op_tanh.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
5555
return out;
5656
} else {
5757
return torch::executor::native::internal::
58-
unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out);
58+
unary_ufunc_realhbbf16_to_floathbf16(
59+
std::tanh, std::tanh, ctx, in, out);
5960
}
6061
}
6162

backends/cadence/hifi/operators/op_rsqrt.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace HiFi {
2121
namespace native {
2222
namespace {
2323

24-
double rsqrt(double x) {
24+
template <typename T>
25+
T rsqrt(T x) {
2526
return 1.0 / std::sqrt(x);
2627
}
2728

@@ -46,7 +47,7 @@ Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
4647
}
4748

4849
return torch::executor::native::internal::
49-
unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out);
50+
unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out);
5051
}
5152

5253
} // namespace native

backends/cadence/hifi/operators/op_tanh.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ Tensor& tanh_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
3535
}
3636

3737
return torch::executor::native::internal::
38-
unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out);
38+
unary_ufunc_realhbbf16_to_floathbf16(std::tanh, std::tanh, ctx, in, out);
3939
}
4040

4141
} // namespace native
4242
} // namespace HiFi
4343
} // namespace impl
44-
} // namespace cadence
44+
} // namespace cadence

kernels/portable/cpu/op_acos.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& acos_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realhbbf16_to_floathbf16(
19-
std::acos, ctx, in, out);
20-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(acos_out, std::acos)
2118

2219
} // namespace native
2320
} // namespace executor

kernels/portable/cpu/op_acosh.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& acosh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realhbbf16_to_floathbf16(
19-
std::acosh, ctx, in, out);
20-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(acosh_out, std::acosh)
2118

2219
} // namespace native
2320
} // namespace executor

kernels/portable/cpu/op_asin.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& asin_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realhbbf16_to_floathbf16(
19-
std::asin, ctx, in, out);
20-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(asin_out, std::asin)
2118

2219
} // namespace native
2320
} // namespace executor

kernels/portable/cpu/op_asinh.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& asinh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realhbbf16_to_floathbf16(
19-
std::asinh, ctx, in, out);
20-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(asinh_out, std::asinh)
2118

2219
} // namespace native
2320
} // namespace executor

0 commit comments

Comments
 (0)