@@ -347,6 +347,10 @@ class ApplyOverDimListPlan {
347347 return ;
348348 }
349349 dim_list_ = dim_list.value ();
350+ if (dim_list_.value ().size () == 1 ) {
351+ mode_ = ExecutionMode::OnlyOneDim;
352+ return ;
353+ }
350354 is_in_dim_list_.fill (0 );
351355 for (const auto & d : dim_list.value ()) {
352356 const size_t non_neg_d = d < 0 ? d + in.dim () : d;
@@ -367,6 +371,14 @@ class ApplyOverDimListPlan {
367371 apply_on_flat_ix_with_stride_and_base (
368372 fn, /* stride=*/ 1 , /* base=*/ 0 , ustart_, uend_);
369373 return ;
374+ case ExecutionMode::OnlyOneDim:
375+ apply_on_flat_and_dim_ix_with_stride_and_base (
376+ [&](const auto in_ix, const auto dim_ix) { fn (in_ix); },
377+ in_.strides ()[ET_NORMALIZE_IX (dim_list_.value ()[0 ], in_.dim ())],
378+ get_init_index (in_, dim_list_.value (), out_ix),
379+ ustart_,
380+ uend_);
381+ return ;
370382 case ExecutionMode::NormalDimMask:
371383 apply_on_flat_ix_with_dim_mask_and_base (
372384 fn,
@@ -399,6 +411,9 @@ class ApplyOverDimListPlan {
399411 // Iterate over the entire tensor with
400412 // apply_on_flat_ix_with_stride_and_base.
401413 NoDimMaskOrZeroDimension,
414+ // dim_list has size 1, iterate with
415+ // apply_on_flat_and_dim_ix_with_stride_and_base
416+ OnlyOneDim,
402417 // General mode, iterate with
403418 // apply_on_flat_ix_with_dim_mask_and_base.
404419 NormalDimMask
0 commit comments