@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
476476// ===----------------------------------------------------------------------===//
477477
478478void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
479- FlatSymbolRefAttr grid,
480- ArrayRef<GridAxesAttr> split_axes,
481- ArrayRef<int64_t > static_halos,
482- ArrayRef<int64_t > static_offsets) {
479+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
480+ ArrayRef<int64_t > staticHalos,
481+ ArrayRef<int64_t > staticOffsets) {
483482 return build (
484- b, odsState, grid, GridAxesArrayAttr::get (b.getContext (), split_axes ),
485- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos ), {},
486- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets ), {});
483+ b, odsState, grid, GridAxesArrayAttr::get (b.getContext (), splitAxes ),
484+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticHalos ), {},
485+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticOffsets ), {});
487486}
488487
489488void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
490- llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes ,
491- ArrayRef<int64_t > static_halos ,
492- ArrayRef<int64_t > static_offsets ) {
489+ llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes ,
490+ ArrayRef<int64_t > staticHalos ,
491+ ArrayRef<int64_t > staticOffsets ) {
493492 return build (b, odsState, FlatSymbolRefAttr::get (b.getContext (), grid),
494- GridAxesArrayAttr::get (b.getContext (), split_axes ),
495- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos ), {},
496- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets ),
493+ GridAxesArrayAttr::get (b.getContext (), splitAxes ),
494+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticHalos ), {},
495+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticOffsets ),
497496 {});
498497}
499498
500499void ShardingOp::build (
501500 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
502- FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes ,
503- ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes ,
504- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets ) {
501+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes ,
502+ ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes ,
503+ ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets ) {
505504 mlir::SmallVector<int64_t > staticHalos, staticDims;
506505 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
507- dispatchIndexOpFoldResults (halo_sizes , dynamicHalos, staticHalos);
508- dispatchIndexOpFoldResults (sharded_dims_offsets , dynamicDims, staticDims);
506+ dispatchIndexOpFoldResults (haloSizes , dynamicHalos, staticHalos);
507+ dispatchIndexOpFoldResults (shardedDimsOffsets , dynamicDims, staticDims);
509508 return build (
510- b, odsState, grid, GridAxesArrayAttr::get (b.getContext (), split_axes ),
509+ b, odsState, grid, GridAxesArrayAttr::get (b.getContext (), splitAxes ),
511510 ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticHalos), dynamicHalos,
512511 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
513512}
@@ -650,14 +649,14 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
650649 if (dynamicOffs.empty () && !staticOffs.empty ()) {
651650 assert (staticOffs.size () >= 2 );
652651 auto diff = staticOffs[1 ] - staticOffs[0 ];
653- bool all_same = staticOffs.size () > 2 ;
652+ bool allSame = staticOffs.size () > 2 ;
654653 for (auto i = 2u ; i < staticOffs.size (); ++i) {
655654 if (staticOffs[i] - staticOffs[i - 1 ] != diff) {
656- all_same = false ;
655+ allSame = false ;
657656 break ;
658657 }
659658 }
660- if (all_same ) {
659+ if (allSame ) {
661660 staticOffs.clear ();
662661 modified = true ;
663662 }
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
749748
750749bool Sharding::operator !=(const Sharding &rhs) const { return !(*this == rhs); }
751750
752- Sharding::Sharding (::mlir::FlatSymbolRefAttr grid_ ) : grid(grid_ ) {}
751+ Sharding::Sharding (::mlir::FlatSymbolRefAttr grid ) : grid(grid ) {}
753752
754753Sharding::Sharding (Value rhs) {
755754 auto shardingOp = rhs.getDefiningOp <ShardingOp>();
@@ -767,32 +766,31 @@ Sharding::Sharding(Value rhs) {
767766 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets ()));
768767}
769768
770- Sharding Sharding::get (::mlir::FlatSymbolRefAttr grid_ ,
771- ArrayRef<GridAxesAttr> split_axes_ ,
772- ArrayRef<int64_t > static_halo_sizes_ ,
773- ArrayRef<int64_t > static_sharded_dims_offsets_ ,
774- ArrayRef<Value> dynamic_halo_sizes_ ,
775- ArrayRef<Value> dynamic_sharded_dims_offsets_ ) {
776- Sharding res (grid_ );
777- if (split_axes_ .empty ()) {
769+ Sharding Sharding::get (::mlir::FlatSymbolRefAttr grid ,
770+ ArrayRef<GridAxesAttr> splitAxes ,
771+ ArrayRef<int64_t > staticHaloSizes ,
772+ ArrayRef<int64_t > staticShardedDimsOffsets ,
773+ ArrayRef<Value> dynamicHaloSizes ,
774+ ArrayRef<Value> dynamicShardedDimsOffsets ) {
775+ Sharding res (grid );
776+ if (splitAxes .empty ()) {
778777 return res;
779778 }
780779
781- res.split_axes .resize (split_axes_.size ());
782- for (auto [i, axis] : llvm::enumerate (split_axes_)) {
783- res.split_axes [i] =
784- GridAxesAttr::get (grid_.getContext (), axis.asArrayRef ());
780+ res.split_axes .resize (splitAxes.size ());
781+ for (auto [i, axis] : llvm::enumerate (splitAxes)) {
782+ res.split_axes [i] = GridAxesAttr::get (grid.getContext (), axis.asArrayRef ());
785783 }
786784
787785 auto clone = [](const auto src, auto &dst) {
788786 dst.resize (src.size ());
789787 llvm::copy (src, dst.begin ());
790788 };
791789
792- clone (static_halo_sizes_ , res.static_halo_sizes );
793- clone (static_sharded_dims_offsets_ , res.static_sharded_dims_offsets );
794- clone (dynamic_halo_sizes_ , res.dynamic_halo_sizes );
795- clone (dynamic_sharded_dims_offsets_ , res.dynamic_sharded_dims_offsets );
790+ clone (staticHaloSizes , res.static_halo_sizes );
791+ clone (staticShardedDimsOffsets , res.static_sharded_dims_offsets );
792+ clone (dynamicHaloSizes , res.dynamic_halo_sizes );
793+ clone (dynamicShardedDimsOffsets , res.dynamic_sharded_dims_offsets );
796794
797795 return res;
798796}
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
809807void ShardShapeOp::build (::mlir::OpBuilder &odsBuilder,
810808 ::mlir::OperationState &odsState,
811809 ::llvm::ArrayRef<int64_t > dims,
812- ArrayRef<Value> dims_dyn , ::mlir::Value sharding,
810+ ArrayRef<Value> dimsDyn , ::mlir::Value sharding,
813811 ::mlir::ValueRange device) {
814812 SmallVector<mlir::Type> resType (dims.size (), odsBuilder.getIndexType ());
815- build (odsBuilder, odsState, resType, dims, dims_dyn , sharding,
813+ build (odsBuilder, odsState, resType, dims, dimsDyn , sharding,
816814 SmallVector<int64_t >(device.size (), ShapedType::kDynamic ), device);
817815}
818816
0 commit comments