@@ -85,6 +85,7 @@ inline void dtype_specialized_elementwise_fn_impl(
   static_assert(
       (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
        ...));
+  static constexpr auto kNumInputs = sizeof...(inputs);
   // All inputs must be of type CTYPE_COMPUTE.
   ET_DCHECK(
       ((inputs.first->scalar_type() ==
@@ -104,9 +105,8 @@ inline void dtype_specialized_elementwise_fn_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        std::array<const CTYPE_COMPUTE*, sizeof...(inputs)>
-            inputs_data_ptrs = {
-                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
 
         CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
@@ -119,11 +119,11 @@ inline void dtype_specialized_elementwise_fn_impl(
           // small-sized tests will test whether using Vectorized broke our
           // lambda.
 #ifndef NDEBUG
-          std::array<Vec, sizeof...(inputs)> loaded_inputs{};
+          std::array<Vec, kNumInputs> loaded_inputs{};
 #else // NDEBUG
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
 #endif // NDEBUG
-          for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+          for (const auto input_idx : c10::irange(kNumInputs)) {
             loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
           }
 #ifndef NDEBUG
@@ -136,8 +136,8 @@ inline void dtype_specialized_elementwise_fn_impl(
         // Main vectorized loop.
         for (auto idx = vectorized_begin; idx < vectorized_end;
              idx += Vec::size()) {
-          std::array<Vec, sizeof...(inputs)> loaded_vec_inputs{};
-          for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+          std::array<Vec, kNumInputs> loaded_vec_inputs{};
+          for (const auto input_idx : c10::irange(kNumInputs)) {
             loaded_vec_inputs[input_idx] =
                 Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
           }
@@ -148,11 +148,11 @@ inline void dtype_specialized_elementwise_fn_impl(
         // Scalar epilogue.
         for (const auto idx : c10::irange(vectorized_end, end)) {
 #ifndef NDEBUG
-          std::array<Vec, sizeof...(inputs)> loaded_inputs{};
+          std::array<Vec, kNumInputs> loaded_inputs{};
 #else // NDEBUG
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
 #endif // NDEBUG
-          for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+          for (const auto input_idx : c10::irange(kNumInputs)) {
             loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
           }
 #ifndef NDEBUG
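The scalar prologue and epilogue above share the debug-only trick the comment in the earlier hunk describes: when NDEBUG is not defined, loaded_inputs holds full Vec values rather than scalars, so even tiny tensors that never reach the main vectorized loop still exercise compute_fun's Vectorized overload. A minimal standalone sketch of that trick, assuming compute_fun has both a scalar and an at::vec::Vectorized overload; the function name and the two-input arity are illustrative, not ExecuTorch API:

#include <array>
#include <cstddef>

#include <ATen/cpu/vec/vec.h>

template <typename Op>
float debug_checked_scalar_step(
    const float* a,
    const float* b,
    std::size_t idx,
    Op compute_fun) {
  using Vec = at::vec::Vectorized<float>;
#ifndef NDEBUG
  // Broadcast each scalar into a full Vec and call the Vectorized overload,
  // so debug builds catch lambdas that only work for scalars.
  std::array<Vec, 2> loaded_inputs = {Vec(a[idx]), Vec(b[idx])};
  float result;
  compute_fun(loaded_inputs[0], loaded_inputs[1]).store(&result, /*count=*/1);
  return result;  // only lane 0 is kept
#else // NDEBUG
  // Release builds do plain scalar math on the prologue/epilogue elements.
  return compute_fun(a[idx], b[idx]);
#endif // NDEBUG
}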
@@ -172,20 +172,20 @@ inline void dtype_specialized_elementwise_fn_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        std::array<const CTYPE_COMPUTE*, sizeof...(inputs)> inputs_data_ptrs = {
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
             inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
 
         CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-        const auto range = BroadcastIndexesRange<
-            sizeof...(inputs),
-            support_noncontiguous_tensors>(out, (*inputs.first)...);
+        const auto range =
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
           const auto& indexes = *begin_it;
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
-          for (const auto idx : c10::irange(sizeof...(inputs))) {
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
+          for (const auto idx : c10::irange(kNumInputs)) {
             loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
           }
           data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
@@ -229,12 +229,14 @@ inline void apply_elementwise_fn_generic_impl(
     const Tensor& out,
     SupportedTensorDtypes out_dtypes,
     Args... inputs) {
+  static constexpr auto kNumInputs = sizeof...(inputs);
+
   struct InputInfo {
     load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
     const char* data_ptr;
     ssize_t element_size;
   };
-  std::array<InputInfo, sizeof...(inputs)> inputs_info = {(InputInfo{
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
       internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
           ctx, *inputs.first, inputs.second),
       reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
@@ -252,15 +254,15 @@ inline void apply_elementwise_fn_generic_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        const auto range = BroadcastIndexesRange<
-            sizeof...(inputs),
-            support_noncontiguous_tensors>(out, (*inputs.first)...);
+        const auto range =
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
           const auto& indexes = *begin_it;
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
-          for (const auto idx : c10::irange(sizeof...(inputs))) {
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
+          for (const auto idx : c10::irange(kNumInputs)) {
             const auto& input_info = inputs_info[idx];
             loaded_inputs[idx] = input_info.load_to_compute(
                 &input_info
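Taken together, the change is mechanical: each function body names sizeof...(inputs) once as a static constexpr kNumInputs and reuses it, which also lets the BroadcastIndexesRange instantiations fit on fewer lines. A minimal standalone sketch of the underlying pattern, naming the pack size once, sizing the staging std::array with it, and dispatching through std::apply (std::array is tuple-like); gather_and_apply and its parameters are hypothetical names, not ExecuTorch API:

#include <array>
#include <cstddef>
#include <tuple> // std::apply

template <typename ComputeT, typename Fn, typename... Ptrs>
void gather_and_apply(ComputeT* out, std::size_t numel, Fn compute_fun, Ptrs... ins) {
  // Name the pack size once instead of repeating sizeof...(ins) below.
  static constexpr auto kNumInputs = sizeof...(ins);
  const std::array<const ComputeT*, kNumInputs> inputs_data_ptrs = {ins...};
  for (std::size_t idx = 0; idx < numel; ++idx) {
    std::array<ComputeT, kNumInputs> loaded_inputs{};
    for (std::size_t i = 0; i < kNumInputs; ++i) {
      loaded_inputs[i] = inputs_data_ptrs[i][idx];
    }
    // std::apply unpacks the array into one argument per input.
    out[idx] = std::apply(compute_fun, loaded_inputs);
  }
}

// Usage:
//   float a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, out[3];
//   gather_and_apply(out, 3, [](float x, float y) { return x + y; }, a, b);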