Skip to content

Commit 508095a

Browse files
committed
comments
adding libs clang-format renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format
1 parent 3c76df3 commit 508095a

File tree

8 files changed

+40
-36
lines changed

8 files changed

+40
-36
lines changed

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
2727
MLIRInferIntRangeInterface
2828
MLIRIR
2929
MLIRMemRefDialect
30+
MLIRMeshDialect
3031
MLIRPass
32+
MLIRShardingInterface
3133
MLIRTensorDialect
3234
MLIRTransforms
3335
MLIRTransformUtils

mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using namespace mlir::mesh;
1919

2020
namespace {
2121

22-
// Sharding of arith.empty/arith.splat
22+
// Sharding of arith.constant
2323
struct ConstantShardingInterface
2424
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
2525
ConstantOp> {

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
286286
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
287287
if (shardOp && sharding == shardOp.getSharding() &&
288288
!shardOp.getAnnotateForUsers()) {
289-
// No need for anything the correct sharding is already set.
289+
// No need for anything if the correct sharding is already set.
290290
return newShardOp ? newShardOp : shardOp;
291291
}
292292

@@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
639639
}
640640
}
641641

642+
// Remove sharded dims offsets if they are effectively the default values,
643+
// e.g. if they define equi-distance between all neighboring shards.
642644
if (offs.second.empty() && !offs.first.empty()) {
643645
assert(offs.first.size() >= 2);
644646
auto diff = offs.first[1] - offs.first[0];
@@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) {
772774
assert(shardingOp && "expected sharding op");
773775
auto splitAxes = shardingOp.getSplitAxes().getAxes();
774776
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
775-
if(splitAxes.empty() && partialAxes.empty()) {
777+
// If splitAxes and partialAxes are empty, use "empty" constructor.
778+
if (splitAxes.empty() && partialAxes.empty()) {
776779
*this = MeshSharding(shardingOp.getMeshAttr());
777780
return;
778781
}
@@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
793796
ArrayRef<Value> dynamic_halo_sizes_,
794797
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
795798
MeshSharding res(mesh_);
796-
if(split_axes_.empty() && partial_axes_.empty()) {
799+
if (split_axes_.empty() && partial_axes_.empty()) {
797800
return res;
798801
}
799802

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
174174
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
175175
return failure();
176176

177-
// check loop types
178-
// SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
179-
// if (loopTypes.empty())
180-
// return failure();
181-
182177
// check maps
183178
SmallVector<AffineMap> maps = getIndexingMaps();
184179
if (maps.empty())
@@ -453,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
453448
AffineMap map) {
454449
Value operandValue = opOperand.get();
455450
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
456-
if(!operandType) {
457-
if(operandValue.getType().isIntOrIndexOrFloat())
451+
if (!operandType) {
452+
if (operandValue.getType().isIntOrIndexOrFloat())
458453
return MeshSharding();
459454
return failure();
460455
}

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
690690
std::vector<MeshSharding> res;
691691
res.reserve(op.getNumResults());
692692
llvm::transform(op.getResults(), std::back_inserter(res),
693-
[&op](OpResult result) {
693+
[](OpResult result) {
694694
TypedValue<RankedTensorType> rankedTensor =
695695
dyn_cast<TypedValue<RankedTensorType>>(result);
696696
if (!rankedTensor) {

mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ using namespace mlir::mesh;
2323
namespace {
2424

2525
// Sharding of tensor.empty/tensor.splat
26-
template<typename OpTy>
26+
template <typename OpTy>
2727
struct CreatorOpShardingInterface
28-
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
28+
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
29+
OpTy> {
2930
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
3031
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
3132
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
3839
auto type = dyn_cast<RankedTensorType>(val.getType());
3940
if (!type)
4041
return {};
41-
return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
42+
return SmallVector<AffineMap>(
43+
op->getNumOperands() + op->getNumResults(),
44+
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
4245
}
4346

4447
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
8285
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
8386
}
8487
}
85-
newOp =
86-
builder.create<OpTy>(op->getLoc(), shardType, newOperands);
88+
newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
8789
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
8890
} else {
8991
// `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
100102
DialectRegistry &registry) {
101103

102104
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
103-
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
104-
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
105+
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
106+
*ctx);
107+
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
108+
*ctx);
105109
});
106110
}

mlir/test/Dialect/Arith/mesh-spmdize.cpp

Lines changed: 0 additions & 17 deletions
This file was deleted.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
3+
// RUN: %s | FileCheck %s
4+
5+
mesh.mesh @mesh4x4(shape = 4x4)
6+
7+
// CHECK-LABEL: func @test_spmdize_constant
8+
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
9+
// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
10+
// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
11+
func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
12+
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
13+
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
14+
%sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
15+
%ci = arith.constant 434 : i32
16+
return %sharding_annotated_1 : tensor<1024x1024xf32>
17+
}

0 commit comments

Comments
 (0)