@@ -48,137 +48,6 @@ namespace impl {
4848namespace HiFi {
4949namespace native {
5050
51- namespace {
52-
53- template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
54- /* * Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
55- bool is_out_of_bounds (CTYPE_VAL val) {
56- const CTYPE_CAST val_cast = static_cast <CTYPE_CAST>(val);
57- return val_cast < std::numeric_limits<CTYPE_OUT>::lowest () ||
58- val_cast > std::numeric_limits<CTYPE_OUT>::max ();
59- }
60-
61- ET_NODISCARD bool check_bounds (
62- const Scalar& val_scalar,
63- const ScalarType& val_type,
64- const ScalarType& out_type,
65- const char * val_name) {
66- auto is_valid = true ;
67-
68- ET_SWITCH_SCALAR_OBJ_TYPES (val_type, ctx, " clamp.out" , CTYPE_VAL, [&]() {
69- CTYPE_VAL val = 0 ;
70- extract_scalar (val_scalar, &val);
71- if (isIntegralType (out_type, /* includeBool=*/ false )) {
72- ET_SWITCH_INT_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
73- if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long >(val)) {
74- ET_LOG (Error, " %s value out of bounds" , val_name);
75- is_valid = false ;
76- }
77- });
78- } else if (isFloatingType (out_type)) {
79- ET_SWITCH_FLOATH_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
80- if (std::isfinite (val) &&
81- is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double >(val)) {
82- ET_LOG (Error, " %s value out of bounds" , val_name);
83- is_valid = false ;
84- }
85- });
86- }
87- });
88-
89- return is_valid;
90- }
91-
92- } // namespace
93-
94- Tensor& clamp_out (
95- KernelRuntimeContext& ctx,
96- const Tensor& in,
97- const exec_aten::optional<Scalar>& min_opt,
98- const exec_aten::optional<Scalar>& max_opt,
99- Tensor& out) {
100- bool has_min = min_opt.has_value ();
101- bool has_max = max_opt.has_value ();
102-
103- ET_KERNEL_CHECK_MSG (
104- ctx,
105- has_min || has_max,
106- InvalidArgument,
107- out,
108- " At least one of 'min' or 'max' must not be None" );
109-
110- // Input Dtypes
111- ScalarType in_type = in.scalar_type ();
112- ScalarType min_type = has_min ? get_scalar_dtype (min_opt.value ()) : in_type;
113- ScalarType max_type = has_max ? get_scalar_dtype (max_opt.value ()) : in_type;
114- ScalarType out_type = out.scalar_type ();
115-
116- // Common Dtype
117- ScalarType common_type = in_type;
118- if (has_min) {
119- common_type = promote_type_with_scalar (common_type, min_opt.value ());
120- }
121- if (has_max) {
122- common_type = promote_type_with_scalar (common_type, max_opt.value ());
123- }
124-
125- // Check Common Dtype
126- ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
127-
128- // Check Scalar Bounds
129- if (has_min) {
130- ET_KERNEL_CHECK (
131- ctx,
132- check_bounds (min_opt.value (), min_type, out_type, " minimum" ),
133- InvalidArgument,
134- out);
135- }
136- if (has_max) {
137- ET_KERNEL_CHECK (
138- ctx,
139- check_bounds (max_opt.value (), max_type, out_type, " maximum" ),
140- InvalidArgument,
141- out);
142- }
143-
144- // Check Dim Order
145- ET_KERNEL_CHECK (
146- ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
147-
148- // Resize
149- ET_KERNEL_CHECK (
150- ctx, resize_tensor (out, in.sizes ()) == Error::Ok, InvalidArgument, out);
151-
152- // Compute Dtype
153- ScalarType compute_type = get_compute_type (common_type);
154-
155- // @lint-ignore CLANGTIDY facebook-hte-CArray
156- static constexpr const char op_name[] = " clamp.out" ;
157-
158- ET_SWITCH_REALB_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
159- apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
160- [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
161- CTYPE_COMPUTE val_out = val_in;
162- if (has_min) {
163- val_out = max_override (
164- val_out, scalar_to<CTYPE_COMPUTE>(min_opt.value ()));
165- }
166- if (has_max) {
167- val_out = min_override (
168- val_out, scalar_to<CTYPE_COMPUTE>(max_opt.value ()));
169- }
170- return val_out;
171- },
172- ctx,
173- in,
174- SupportedTensorDtypes::REALHBBF16,
175- out,
176- SupportedTensorDtypes::SAME_AS_COMMON);
177- });
178-
179- return out;
180- }
181-
18251Tensor& clamp_tensor_out (
18352 RuntimeContext& ctx,
18453 const Tensor& in,
0 commit comments