@@ -73,74 +73,91 @@ Tensor& clamp_out(
7373 const exec_aten::optional<Scalar>& min_opt,
7474 const exec_aten::optional<Scalar>& max_opt,
7575 Tensor& out) {
76- (void )ctx;
76+ ET_KERNEL_CHECK (
77+ ctx,
78+ (executorch::runtime::tensor_is_realhbbf16_type (in) &&
79+ executorch::runtime::tensor_is_realhbbf16_type (out)),
80+ InvalidArgument,
81+ out);
82+
83+ bool has_min = min_opt.has_value ();
84+ bool has_max = max_opt.has_value ();
7785
7886 ET_KERNEL_CHECK_MSG (
7987 ctx,
80- resize_tensor (out, in. sizes ()) == Error::Ok ,
88+ has_min || has_max ,
8189 InvalidArgument,
8290 out,
83- " Failed to resize output tensor." );
84-
85- ET_KERNEL_CHECK (
86- ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
91+ " At least one of 'min' or 'max' must not be None" );
8792
93+ // Input Dtypes
8894 ScalarType in_type = in.scalar_type ();
89- ScalarType min_type = in_type;
90- ScalarType max_type = in_type;
91- ScalarType common_type = in_type;
95+ ScalarType min_type =
96+ has_min ? utils::get_scalar_dtype (min_opt.value ()) : in_type;
97+ ScalarType max_type =
98+ has_max ? utils::get_scalar_dtype (max_opt.value ()) : in_type;
9299 ScalarType out_type = out.scalar_type ();
93100
94- bool has_min = min_opt.has_value ();
101+ // Common Dtype
102+ ScalarType common_type = in_type;
95103 if (has_min) {
96- min_type = utils::get_scalar_dtype (min_opt.value ());
97104 common_type = utils::promote_type_with_scalar (common_type, min_opt.value ());
105+ }
106+ if (has_max) {
107+ common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
108+ }
109+
110+ // Check Common Dtype
111+ ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
112+
113+ // Check Scalar Bounds
114+ if (has_min) {
98115 ET_KERNEL_CHECK (
99116 ctx,
100117 check_bounds (min_opt.value (), min_type, out_type, " minimum" ),
101118 InvalidArgument,
102119 out);
103120 }
104- bool has_max = max_opt.has_value ();
105121 if (has_max) {
106- max_type = utils::get_scalar_dtype (max_opt.value ());
107- common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
108122 ET_KERNEL_CHECK (
109123 ctx,
110124 check_bounds (max_opt.value (), max_type, out_type, " maximum" ),
111125 InvalidArgument,
112126 out);
113127 }
114128
115- ET_KERNEL_CHECK_MSG (
116- ctx,
117- has_min || has_max,
118- InvalidArgument,
119- out,
120- " At least one of 'min' or 'max' must not be None" );
129+ // Check Dim Order
130+ ET_KERNEL_CHECK (
131+ ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
121132
122- ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
133+ // Resize
134+ ET_KERNEL_CHECK (
135+ ctx, resize_tensor (out, in.sizes ()) == Error::Ok, InvalidArgument, out);
136+
137+ // Compute Dtype
138+ ScalarType compute_type = utils::get_compute_type (common_type);
123139
140+ // @lint-ignore CLANGTIDY facebook-hte-CArray
124141 static constexpr const char op_name[] = " clamp.out" ;
125142
126- ET_SWITCH_REALHB_TYPES (common_type , ctx, op_name, CTYPE_COMMON , [&]() {
127- utils::apply_unitensor_elementwise_fn<CTYPE_COMMON , op_name>(
128- [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
129- CTYPE_COMMON val_out = val_in;
143+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
144+ utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE , op_name>(
145+ [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
146+ CTYPE_COMPUTE val_out = val_in;
130147 if (has_min) {
131148 val_out = utils::max_override (
132- val_out, utils::scalar_to<CTYPE_COMMON >(min_opt.value ()));
149+ val_out, utils::scalar_to<CTYPE_COMPUTE >(min_opt.value ()));
133150 }
134151 if (has_max) {
135152 val_out = utils::min_override (
136- val_out, utils::scalar_to<CTYPE_COMMON >(max_opt.value ()));
153+ val_out, utils::scalar_to<CTYPE_COMPUTE >(max_opt.value ()));
137154 }
138155 return val_out;
139156 },
140157 in,
141158 utils::SupportedTensorDtypes::REALHBBF16,
142159 out,
143- utils::SupportedTensorDtypes::REALHBBF16 );
160+ utils::SupportedTensorDtypes::SAME_AS_COMMON );
144161 });
145162
146163 return out;
@@ -152,8 +169,6 @@ Tensor& clamp_tensor_out(
152169 const exec_aten::optional<Tensor>& min_opt,
153170 const exec_aten::optional<Tensor>& max_opt,
154171 Tensor& out) {
155- (void )ctx;
156-
157172 bool has_min = min_opt.has_value ();
158173 bool has_max = max_opt.has_value ();
159174
@@ -167,42 +182,55 @@ Tensor& clamp_tensor_out(
167182 const Tensor& min = has_min ? min_opt.value () : in;
168183 const Tensor& max = has_max ? max_opt.value () : in;
169184
185+ ET_KERNEL_CHECK (
186+ ctx,
187+ (executorch::runtime::tensor_is_realhbbf16_type (in) &&
188+ executorch::runtime::tensor_is_realhbbf16_type (min) &&
189+ executorch::runtime::tensor_is_realhbbf16_type (max) &&
190+ executorch::runtime::tensor_is_realhbbf16_type (out)),
191+ InvalidArgument,
192+ out);
193+
194+ // Common Dtype
195+ ScalarType common_type = in.scalar_type ();
196+ if (has_min) {
197+ common_type = promoteTypes (common_type, min.scalar_type ());
198+ }
199+ if (has_max) {
200+ common_type = promoteTypes (common_type, max.scalar_type ());
201+ }
202+
203+ // Check Common Dtype
204+ ET_KERNEL_CHECK (
205+ ctx, canCast (common_type, out.scalar_type ()), InvalidArgument, out);
206+
207+ // Check Dim Order
170208 ET_KERNEL_CHECK (
171209 ctx,
172210 tensors_have_same_dim_order (in, min, max, out),
173211 InvalidArgument,
174212 out);
175213
214+ // Resize
176215 ET_KERNEL_CHECK (
177216 ctx,
178217 resize_to_broadcast_target_size (in, min, max, out) == Error::Ok,
179218 InvalidArgument,
180219 out);
181220
182- ScalarType in_type = in.scalar_type ();
183- ScalarType min_type = min.scalar_type ();
184- ScalarType max_type = max.scalar_type ();
185- ScalarType common_type = in_type;
186- ScalarType out_type = out.scalar_type ();
187-
188- if (has_min) {
189- common_type = promoteTypes (common_type, min_type, /* half_to_float*/ true );
190- }
191- if (has_max) {
192- common_type = promoteTypes (common_type, max_type, /* half_to_float*/ true );
193- }
194-
195- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
221+ // Compute Dtype
222+ ScalarType compute_type = utils::get_compute_type (common_type);
196223
224+ // @lint-ignore CLANGTIDY facebook-hte-CArray
197225 static constexpr const char op_name[] = " clamp.Tensor_out" ;
198226
199- ET_SWITCH_REALHB_TYPES (common_type , ctx, op_name, CTYPE_COMMON , [&]() {
200- utils::apply_tritensor_elementwise_fn<CTYPE_COMMON , op_name>(
227+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
228+ utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE , op_name>(
201229 [has_min, has_max](
202- const CTYPE_COMMON val_in,
203- const CTYPE_COMMON val_min,
204- const CTYPE_COMMON val_max) {
205- CTYPE_COMMON val_out = val_in;
230+ const CTYPE_COMPUTE val_in,
231+ const CTYPE_COMPUTE val_min,
232+ const CTYPE_COMPUTE val_max) {
233+ CTYPE_COMPUTE val_out = val_in;
206234 if (has_min) {
207235 val_out = utils::max_override (val_out, val_min);
208236 }
0 commit comments