@@ -73,74 +73,90 @@ Tensor& clamp_out(
7373 const exec_aten::optional<Scalar>& min_opt,
7474 const exec_aten::optional<Scalar>& max_opt,
7575 Tensor& out) {
76- (void )ctx;
76+ ET_KERNEL_CHECK (
77+ ctx,
78+ (executorch::runtime::tensor_is_realhbbf16_type (in) &&
79+ executorch::runtime::tensor_is_realhbbf16_type (out)),
80+ InvalidArgument,
81+ out);
82+
83+ bool has_min = min_opt.has_value ();
84+ bool has_max = max_opt.has_value ();
7785
7886 ET_KERNEL_CHECK_MSG (
7987 ctx,
80- resize_tensor (out, in. sizes ()) == Error::Ok ,
88+ has_min || has_max ,
8189 InvalidArgument,
8290 out,
83- " Failed to resize output tensor." );
84-
85- ET_KERNEL_CHECK (
86- ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
91+ " At least one of 'min' or 'max' must not be None" );
8792
93+ // Input Dtypes
8894 ScalarType in_type = in.scalar_type ();
89- ScalarType min_type = in_type;
90- ScalarType max_type = in_type;
91- ScalarType common_type = in_type;
95+ ScalarType min_type =
96+ has_min ? utils::get_scalar_dtype (min_opt.value ()) : in_type;
97+ ScalarType max_type =
98+ has_max ? utils::get_scalar_dtype (max_opt.value ()) : in_type;
9299 ScalarType out_type = out.scalar_type ();
93100
94- bool has_min = min_opt.has_value ();
101+ // Common Dtype
102+ ScalarType common_type = in_type;
95103 if (has_min) {
96- min_type = utils::get_scalar_dtype (min_opt.value ());
97104 common_type = utils::promote_type_with_scalar (common_type, min_opt.value ());
105+ }
106+ if (has_max) {
107+ common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
108+ }
109+
110+ // Check Common Dtype
111+ ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
112+
113+ // Check Scalar Bounds
114+ if (has_min) {
98115 ET_KERNEL_CHECK (
99116 ctx,
100117 check_bounds (min_opt.value (), min_type, out_type, " minimum" ),
101118 InvalidArgument,
102119 out);
103120 }
104- bool has_max = max_opt.has_value ();
105121 if (has_max) {
106- max_type = utils::get_scalar_dtype (max_opt.value ());
107- common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
108122 ET_KERNEL_CHECK (
109123 ctx,
110124 check_bounds (max_opt.value (), max_type, out_type, " maximum" ),
111125 InvalidArgument,
112126 out);
113127 }
114128
115- ET_KERNEL_CHECK_MSG (
116- ctx,
117- has_min || has_max,
118- InvalidArgument,
119- out,
120- " At least one of 'min' or 'max' must not be None" );
129+ // Check Dim Order
130+ ET_KERNEL_CHECK (
131+ ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
121132
122- ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
133+ // Resize
134+ ET_KERNEL_CHECK (
135+ ctx, resize_tensor (out, in.sizes ()) == Error::Ok, InvalidArgument, out);
136+
137+ // Compute Dtype
138+ ScalarType compute_type = utils::get_compute_type (common_type);
123139
124140 static constexpr const char op_name[] = " clamp.out" ;
125141
126- ET_SWITCH_REALHB_TYPES (common_type , ctx, op_name, CTYPE_COMMON , [&]() {
127- utils::apply_unitensor_elementwise_fn<CTYPE_COMMON , op_name>(
128- [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
129- CTYPE_COMMON val_out = val_in;
142+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
143+ utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE , op_name>(
144+ [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
145+ CTYPE_COMPUTE val_out = val_in;
130146 if (has_min) {
131147 val_out = utils::max_override (
132- val_out, utils::scalar_to<CTYPE_COMMON >(min_opt.value ()));
148+ val_out, utils::scalar_to<CTYPE_COMPUTE >(min_opt.value ()));
133149 }
134150 if (has_max) {
135151 val_out = utils::min_override (
136- val_out, utils::scalar_to<CTYPE_COMMON >(max_opt.value ()));
152+ val_out, utils::scalar_to<CTYPE_COMPUTE >(max_opt.value ()));
137153 }
138154 return val_out;
139155 },
140156 in,
141157 utils::SupportedTensorDtypes::REALHBBF16,
142158 out,
143- utils::SupportedTensorDtypes::REALHBBF16 );
159+ utils::SupportedTensorDtypes::SAME_AS_COMMON );
144160 });
145161
146162 return out;
@@ -152,8 +168,6 @@ Tensor& clamp_tensor_out(
152168 const exec_aten::optional<Tensor>& min_opt,
153169 const exec_aten::optional<Tensor>& max_opt,
154170 Tensor& out) {
155- (void )ctx;
156-
157171 bool has_min = min_opt.has_value ();
158172 bool has_max = max_opt.has_value ();
159173
@@ -167,42 +181,54 @@ Tensor& clamp_tensor_out(
167181 const Tensor& min = has_min ? min_opt.value () : in;
168182 const Tensor& max = has_max ? max_opt.value () : in;
169183
184+ ET_KERNEL_CHECK (
185+ ctx,
186+ (executorch::runtime::tensor_is_realhbbf16_type (in) &&
187+ executorch::runtime::tensor_is_realhbbf16_type (min) &&
188+ executorch::runtime::tensor_is_realhbbf16_type (max) &&
189+ executorch::runtime::tensor_is_realhbbf16_type (out)),
190+ InvalidArgument,
191+ out);
192+
193+ // Common Dtype
194+ ScalarType common_type = in.scalar_type ();
195+ if (has_min) {
196+ common_type = promoteTypes (common_type, min.scalar_type ());
197+ }
198+ if (has_max) {
199+ common_type = promoteTypes (common_type, max.scalar_type ());
200+ }
201+
202+ // Check Common Dtype
203+ ET_KERNEL_CHECK (
204+ ctx, canCast (common_type, out.scalar_type ()), InvalidArgument, out);
205+
206+ // Check Dim Order
170207 ET_KERNEL_CHECK (
171208 ctx,
172209 tensors_have_same_dim_order (in, min, max, out),
173210 InvalidArgument,
174211 out);
175212
213+ // Resize
176214 ET_KERNEL_CHECK (
177215 ctx,
178216 resize_to_broadcast_target_size (in, min, max, out) == Error::Ok,
179217 InvalidArgument,
180218 out);
181219
182- ScalarType in_type = in.scalar_type ();
183- ScalarType min_type = min.scalar_type ();
184- ScalarType max_type = max.scalar_type ();
185- ScalarType common_type = in_type;
186- ScalarType out_type = out.scalar_type ();
187-
188- if (has_min) {
189- common_type = promoteTypes (common_type, min_type, /* half_to_float*/ true );
190- }
191- if (has_max) {
192- common_type = promoteTypes (common_type, max_type, /* half_to_float*/ true );
193- }
194-
195- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
220+ // Compute Dtype
221+ ScalarType compute_type = utils::get_compute_type (common_type);
196222
197223 static constexpr const char op_name[] = " clamp.Tensor_out" ;
198224
199- ET_SWITCH_REALHB_TYPES (common_type , ctx, op_name, CTYPE_COMMON , [&]() {
200- utils::apply_tritensor_elementwise_fn<CTYPE_COMMON , op_name>(
225+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
226+ utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE , op_name>(
201227 [has_min, has_max](
202- const CTYPE_COMMON val_in,
203- const CTYPE_COMMON val_min,
204- const CTYPE_COMMON val_max) {
205- CTYPE_COMMON val_out = val_in;
228+ const CTYPE_COMPUTE val_in,
229+ const CTYPE_COMPUTE val_min,
230+ const CTYPE_COMPUTE val_max) {
231+ CTYPE_COMPUTE val_out = val_in;
206232 if (has_min) {
207233 val_out = utils::max_override (val_out, val_min);
208234 }
0 commit comments