@@ -629,30 +629,33 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
629629 bool modified = succeeded (foldDynamicIndexList (mixedHalos, true )) ||
630630 succeeded (foldDynamicIndexList (mixedOffs, true ));
631631
632- auto halos = decomposeMixedValues (mixedHalos);
633- auto offs = decomposeMixedValues (mixedOffs);
632+ auto [staticHalos, dynamicHalos] = decomposeMixedValues (mixedHalos);
633+ auto [staticOffs, dynamicOffs] = decomposeMixedValues (mixedOffs);
634634
635- if (halos. second . empty () && !halos. first .empty ()) {
636- if (halos. first [0 ] == 0 && llvm::all_equal (halos. first )) {
637- halos. first .clear ();
635+ if (dynamicHalos. empty () && !staticHalos .empty ()) {
636+ if (staticHalos [0 ] == 0 && llvm::all_equal (staticHalos )) {
637+ staticHalos .clear ();
638638 modified = true ;
639639 }
640640 }
641641
642642 // Remove sharded dims offsets if they are effectively the default values,
643643 // e.g. if they define equi-distance between all neighboring shards.
644- if (offs.second .empty () && !offs.first .empty ()) {
645- assert (offs.first .size () >= 2 );
646- auto diff = offs.first [1 ] - offs.first [0 ];
647- bool all_same = offs.first .size () > 2 ;
648- for (auto i = 2u ; i < offs.first .size (); ++i) {
649- if (offs.first [i] - offs.first [i - 1 ] != diff) {
644+ // Requires static-only offsets. Compares the first distance as the
645+ // difference between the first two offsets. Only if all consecutive
646+ // distances are the same, the offsets are removed.
647+ if (dynamicOffs.empty () && !staticOffs.empty ()) {
648+ assert (staticOffs.size () >= 2 );
649+ auto diff = staticOffs[1 ] - staticOffs[0 ];
650+ bool all_same = staticOffs.size () > 2 ;
651+ for (auto i = 2u ; i < staticOffs.size (); ++i) {
652+ if (staticOffs[i] - staticOffs[i - 1 ] != diff) {
650653 all_same = false ;
651654 break ;
652655 }
653656 }
654657 if (all_same) {
655- offs. first .clear ();
658+ staticOffs .clear ();
656659 modified = true ;
657660 }
658661 }
@@ -661,10 +664,10 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
661664 return failure ();
662665 }
663666
664- op.setStaticHaloSizes (halos. first );
665- op.getDynamicHaloSizesMutable ().assign (halos. second );
666- op.setStaticShardedDimsOffsets (offs. first );
667- op.getDynamicShardedDimsOffsetsMutable ().assign (offs. second );
667+ op.setStaticHaloSizes (staticHalos );
668+ op.getDynamicHaloSizesMutable ().assign (dynamicHalos );
669+ op.setStaticShardedDimsOffsets (staticOffs );
670+ op.getDynamicShardedDimsOffsetsMutable ().assign (dynamicOffs );
668671
669672 return success ();
670673 }
0 commit comments