Skip to content

Commit 99cf24e

Browse files
committed
comments and nicer code (from review)
1 parent e8ccad1 commit 99cf24e

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -629,30 +629,33 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
629629
bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
630630
succeeded(foldDynamicIndexList(mixedOffs, true));
631631

632-
auto halos = decomposeMixedValues(mixedHalos);
633-
auto offs = decomposeMixedValues(mixedOffs);
632+
auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
633+
auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
634634

635-
if (halos.second.empty() && !halos.first.empty()) {
636-
if (halos.first[0] == 0 && llvm::all_equal(halos.first)) {
637-
halos.first.clear();
635+
if (dynamicHalos.empty() && !staticHalos.empty()) {
636+
if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
637+
staticHalos.clear();
638638
modified = true;
639639
}
640640
}
641641

642642
// Remove sharded dims offsets if they are effectively the default values,
643643
// e.g. if they define equi-distance between all neighboring shards.
644-
if (offs.second.empty() && !offs.first.empty()) {
645-
assert(offs.first.size() >= 2);
646-
auto diff = offs.first[1] - offs.first[0];
647-
bool all_same = offs.first.size() > 2;
648-
for (auto i = 2u; i < offs.first.size(); ++i) {
649-
if (offs.first[i] - offs.first[i - 1] != diff) {
644+
// Requires static-only offsets. Compares the first distance as the
645+
// difference between the first two offsets. Only if all consecutive
646+
// distances are the same, the offsets are removed.
647+
if (dynamicOffs.empty() && !staticOffs.empty()) {
648+
assert(staticOffs.size() >= 2);
649+
auto diff = staticOffs[1] - staticOffs[0];
650+
bool all_same = staticOffs.size() > 2;
651+
for (auto i = 2u; i < staticOffs.size(); ++i) {
652+
if (staticOffs[i] - staticOffs[i - 1] != diff) {
650653
all_same = false;
651654
break;
652655
}
653656
}
654657
if (all_same) {
655-
offs.first.clear();
658+
staticOffs.clear();
656659
modified = true;
657660
}
658661
}
@@ -661,10 +664,10 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
661664
return failure();
662665
}
663666

664-
op.setStaticHaloSizes(halos.first);
665-
op.getDynamicHaloSizesMutable().assign(halos.second);
666-
op.setStaticShardedDimsOffsets(offs.first);
667-
op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
667+
op.setStaticHaloSizes(staticHalos);
668+
op.getDynamicHaloSizesMutable().assign(dynamicHalos);
669+
op.setStaticShardedDimsOffsets(staticOffs);
670+
op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
668671

669672
return success();
670673
}

mlir/test/Dialect/Mesh/ops.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ func.func @mesh_shard_shape() {
167167
// CHECK-LABEL: func @mesh_get_sharding
168168
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
169169
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
170-
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
171-
%s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
172170
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
173171
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
174172
return %0 : !mesh.sharding

0 commit comments

Comments
 (0)