@@ -327,6 +327,10 @@ class ApplyOverDimListPlan {
327327 return ;
328328 }
329329 dim_list_ = dim_list.value ();
330+ if (dim_list_.value ().size () == 1 ) {
331+ mode_ = ExecutionMode::OnlyOneDim;
332+ return ;
333+ }
330334 is_in_dim_list_.fill (0 );
331335 for (const auto & d : dim_list.value ()) {
332336 const size_t non_neg_d = d < 0 ? d + in.dim () : d;
@@ -347,6 +351,16 @@ class ApplyOverDimListPlan {
347351 apply_on_flat_ix_with_stride_and_base (
348352 fn, /* stride=*/ 1 , /* base=*/ 0 , ustart_, uend_);
349353 return ;
354+ case ExecutionMode::OnlyOneDim:
355+ apply_on_flat_and_dim_ix_with_stride_and_base (
356+ [&](const auto in_ix, const auto dim_ix) {
357+ fn (in_ix);
358+ },
359+ in_.strides ()[ET_NORMALIZE_IX (dim_list_.value ()[0 ], in_.dim ())],
360+ get_init_index (in_, dim_list_.value (), out_ix),
361+ ustart_,
362+ uend_);
363+ return ;
350364 case ExecutionMode::NormalDimMask:
351365 apply_on_flat_ix_with_dim_mask_and_base (
352366 fn,
@@ -379,6 +393,8 @@ class ApplyOverDimListPlan {
379393 // Iterate over the entire tensor with
380394 // apply_on_flat_ix_with_stride_and_base.
381395 NoDimMaskOrZeroDimension,
396+ // dim_list has size 1, iterate with apply_on_flat_and_dim_ix_with_stride_and_base
397+ OnlyOneDim,
382398 // General mode, iterate with
383399 // apply_on_flat_ix_with_dim_mask_and_base.
384400 NormalDimMask
0 commit comments