@@ -120,21 +120,47 @@ Tensor& opt_mul_out(
         out,
         "Failed to resize output tensor.");
 
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    }
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
-          ctx, mul_lambda, a, b, out, selected_optimized_path);
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    }
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
@@ -146,26 +172,42 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out);
 
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
+    if (executorch::runtime::isComplexType(a_type) ||
+        executorch::runtime::isComplexType(b_type) ||
+        executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+            [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
+            a,
+            b,
+            out);
+      });
+    } else {
+      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
+        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
+          using CTYPE_IN = typename torch::executor::
+              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+          ET_SWITCH_REALHBBF16_TYPES(
+              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
+                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
+                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                      CTYPE_IN value = a_casted * b_casted;
+
+                      return static_cast<CTYPE_OUT>(value);
+                    },
+                    a,
+                    b,
+                    out);
+              });
         });
       });
-    });
+    }
   }
 
   return out;
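
Each of the three added complex branches reduces to a plain element-wise complex product once the new `ET_KERNEL_CHECK` passes: `a`, `b`, and `out` must all share one complex dtype, and no promotion is attempted. Below is a minimal standalone sketch of that behavior, using `std::complex` instead of ExecuTorch's vectorized types; the `mul_complex_out` helper name is made up for illustration.

```cpp
// Standalone sketch (not ExecuTorch code): the complex branches boil down to
// out[i] = a[i] * b[i] over buffers that share one complex dtype.
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the complex path: no type promotion,
// just an element-wise product.
template <typename T>
void mul_complex_out(
    const std::complex<T>* a,
    const std::complex<T>* b,
    std::complex<T>* out,
    std::size_t numel) {
  for (std::size_t i = 0; i < numel; ++i) {
    out[i] = a[i] * b[i];
  }
}

int main() {
  std::vector<std::complex<float>> a = {{1.f, 2.f}, {3.f, -1.f}};
  std::vector<std::complex<float>> b = {{0.f, 1.f}, {2.f, 2.f}};
  std::vector<std::complex<float>> out(a.size());
  mul_complex_out(a.data(), b.data(), out.data(), a.size());
  // out[0] == (-2, 1) since (1 + 2i) * i = -2 + i
  // out[1] == (8, 4)  since (3 - i) * (2 + 2i) = 8 + 4i
  return 0;
}
```

Unlike the real/bool fallback path, which promotes mixed input dtypes via `promoteTypes`, mismatched complex and real inputs are rejected outright by the added check.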