Skip to content

Commit 6049289

Browse files
committed
[MLIR] Apply clang-tidy fixes for readability-identifier-naming in ShardOps.cpp (NFC)
1 parent 1bbff72 commit 6049289

File tree

1 file changed

+39
-41
lines changed

1 file changed

+39
-41
lines changed

mlir/lib/Dialect/Shard/IR/ShardOps.cpp

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
476476
//===----------------------------------------------------------------------===//
477477

478478
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
479-
FlatSymbolRefAttr grid,
480-
ArrayRef<GridAxesAttr> split_axes,
481-
ArrayRef<int64_t> static_halos,
482-
ArrayRef<int64_t> static_offsets) {
479+
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
480+
ArrayRef<int64_t> staticHalos,
481+
ArrayRef<int64_t> staticOffsets) {
483482
return build(
484-
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
485-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
486-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
483+
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
484+
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
485+
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
487486
}
488487

489488
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
490-
llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
491-
ArrayRef<int64_t> static_halos,
492-
ArrayRef<int64_t> static_offsets) {
489+
llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
490+
ArrayRef<int64_t> staticHalos,
491+
ArrayRef<int64_t> staticOffsets) {
493492
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
494-
GridAxesArrayAttr::get(b.getContext(), split_axes),
495-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
496-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
493+
GridAxesArrayAttr::get(b.getContext(), splitAxes),
494+
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
495+
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
497496
{});
498497
}
499498

500499
void ShardingOp::build(
501500
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
502-
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
503-
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
504-
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
501+
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
502+
::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes,
503+
::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
505504
mlir::SmallVector<int64_t> staticHalos, staticDims;
506505
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
507-
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
508-
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
506+
dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
507+
dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
509508
return build(
510-
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
509+
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
511510
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
512511
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
513512
}
@@ -650,14 +649,14 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
650649
if (dynamicOffs.empty() && !staticOffs.empty()) {
651650
assert(staticOffs.size() >= 2);
652651
auto diff = staticOffs[1] - staticOffs[0];
653-
bool all_same = staticOffs.size() > 2;
652+
bool allSame = staticOffs.size() > 2;
654653
for (auto i = 2u; i < staticOffs.size(); ++i) {
655654
if (staticOffs[i] - staticOffs[i - 1] != diff) {
656-
all_same = false;
655+
allSame = false;
657656
break;
658657
}
659658
}
660-
if (all_same) {
659+
if (allSame) {
661660
staticOffs.clear();
662661
modified = true;
663662
}
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
749748

750749
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
751750

752-
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
751+
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {}
753752

754753
Sharding::Sharding(Value rhs) {
755754
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
@@ -767,32 +766,31 @@ Sharding::Sharding(Value rhs) {
767766
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
768767
}
769768

770-
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
771-
ArrayRef<GridAxesAttr> split_axes_,
772-
ArrayRef<int64_t> static_halo_sizes_,
773-
ArrayRef<int64_t> static_sharded_dims_offsets_,
774-
ArrayRef<Value> dynamic_halo_sizes_,
775-
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
776-
Sharding res(grid_);
777-
if (split_axes_.empty()) {
769+
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid,
770+
ArrayRef<GridAxesAttr> splitAxes,
771+
ArrayRef<int64_t> staticHaloSizes,
772+
ArrayRef<int64_t> staticShardedDimsOffsets,
773+
ArrayRef<Value> dynamicHaloSizes,
774+
ArrayRef<Value> dynamicShardedDimsOffsets) {
775+
Sharding res(grid);
776+
if (splitAxes.empty()) {
778777
return res;
779778
}
780779

781-
res.split_axes.resize(split_axes_.size());
782-
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
783-
res.split_axes[i] =
784-
GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
780+
res.split_axes.resize(splitAxes.size());
781+
for (auto [i, axis] : llvm::enumerate(splitAxes)) {
782+
res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
785783
}
786784

787785
auto clone = [](const auto src, auto &dst) {
788786
dst.resize(src.size());
789787
llvm::copy(src, dst.begin());
790788
};
791789

792-
clone(static_halo_sizes_, res.static_halo_sizes);
793-
clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
794-
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
795-
clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
790+
clone(staticHaloSizes, res.static_halo_sizes);
791+
clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
792+
clone(dynamicHaloSizes, res.dynamic_halo_sizes);
793+
clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
796794

797795
return res;
798796
}
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
809807
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
810808
::mlir::OperationState &odsState,
811809
::llvm::ArrayRef<int64_t> dims,
812-
ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
810+
ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
813811
::mlir::ValueRange device) {
814812
SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
815-
build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
813+
build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
816814
SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
817815
}
818816

0 commit comments

Comments
 (0)