@@ -479,37 +479,23 @@ void MeshShapeOp::getAsmResultNames(
479479void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480480 FlatSymbolRefAttr mesh,
481481 ArrayRef<MeshAxesAttr> split_axes,
482- ArrayRef<MeshAxis> partial_axes,
483- mesh::ReductionKind partial_type,
484482 ArrayRef<int64_t > static_halos,
485483 ArrayRef<int64_t > static_offsets) {
486484 return build (
487485 b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes),
488- ::mlir::DenseI16ArrayAttr::get (b.getContext(), partial_axes),
489- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
490486 ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos), {},
491487 ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
492488}
493489
494- void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
495- FlatSymbolRefAttr mesh,
496- ArrayRef<MeshAxesAttr> split_axes) {
497- return build (
498- b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
499- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
500- {}, {}, {}, {});
501- }
502-
503490void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
504491 llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
505492 ArrayRef<int64_t > static_halos,
506493 ArrayRef<int64_t > static_offsets) {
507- return build (
508- b, odsState, FlatSymbolRefAttr::get (b.getContext (), mesh),
509- MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
510- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
511- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
512- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
494+ return build (b, odsState, FlatSymbolRefAttr::get (b.getContext (), mesh),
495+ MeshAxesArrayAttr::get (b.getContext (), split_axes),
496+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos), {},
497+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets),
498+ {});
513499}
514500
515501void ShardingOp::build (
@@ -522,8 +508,7 @@ void ShardingOp::build(
522508 dispatchIndexOpFoldResults (halo_sizes, dynamicHalos, staticHalos);
523509 dispatchIndexOpFoldResults (sharded_dims_offsets, dynamicDims, staticDims);
524510 return build (
525- b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
526- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
511+ b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes),
527512 ::mlir::DenseI64ArrayAttr::get (b.getContext(), staticHalos), dynamicHalos,
528513 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
529514}
@@ -533,11 +518,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
533518
534519 build (b, odsState, ShardingType::get (b.getContext ()), from.getMeshAttr (),
535520 MeshAxesArrayAttr::get (b.getContext (), from.getSplitAxes ()),
536- from.getPartialAxes ().empty ()
537- ? DenseI16ArrayAttr ()
538- : b.getDenseI16ArrayAttr (from.getPartialAxes ()),
539- ::mlir::mesh::ReductionKindAttr::get (b.getContext(),
540- from.getPartialType()),
541521 from.getStaticShardedDimsOffsets ().empty ()
542522 ? DenseI64ArrayAttr ()
543523 : b.getDenseI64ArrayAttr (from.getStaticShardedDimsOffsets ()),
@@ -566,9 +546,6 @@ LogicalResult ShardingOp::verify() {
566546 if (failed (checkMeshAxis (subAxesArray)))
567547 return failure ();
568548 }
569- if (getPartialAxes ().has_value () &&
570- failed (checkMeshAxis (getPartialAxes ().value ())))
571- return failure ();
572549
573550 if (!getStaticHaloSizes ().empty () && !getStaticShardedDimsOffsets ().empty ()) {
574551 return emitOpError (" halo sizes and shard offsets are mutually exclusive" );
@@ -710,17 +687,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
710687// MeshSharding
711688// ===----------------------------------------------------------------------===//
712689
713- bool MeshSharding::equalSplitAndPartialAxes (const MeshSharding &rhs) const {
690+ bool MeshSharding::equalSplitAxes (const MeshSharding &rhs) const {
714691 if (getMesh () != rhs.getMesh ()) {
715692 return false ;
716693 }
717694
718- if (getPartialAxes ().size () != rhs.getPartialAxes ().size () ||
719- (!getPartialAxes ().empty () && getPartialType () != rhs.getPartialType ()) ||
720- !llvm::equal (getPartialAxes (), rhs.getPartialAxes ())) {
721- return false ;
722- }
723-
724695 auto minSize = std::min (getSplitAxes ().size (), rhs.getSplitAxes ().size ());
725696 if (!llvm::equal (llvm::make_range (getSplitAxes ().begin (),
726697 getSplitAxes ().begin () + minSize),
@@ -768,13 +739,13 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
768739}
769740
770741bool MeshSharding::operator ==(Value rhs) const {
771- return equalSplitAndPartialAxes (rhs) && equalHaloAndShardSizes (rhs);
742+ return equalSplitAxes (rhs) && equalHaloAndShardSizes (rhs);
772743}
773744
774745bool MeshSharding::operator !=(Value rhs) const { return !(*this == rhs); }
775746
776747bool MeshSharding::operator ==(const MeshSharding &rhs) const {
777- return equalSplitAndPartialAxes (rhs) && equalHaloAndShardSizes (rhs);
748+ return equalSplitAxes (rhs) && equalHaloAndShardSizes (rhs);
778749}
779750
780751bool MeshSharding::operator !=(const MeshSharding &rhs) const {
@@ -787,30 +758,26 @@ MeshSharding::MeshSharding(Value rhs) {
787758 auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp ());
788759 assert (shardingOp && " expected sharding op" );
789760 auto splitAxes = shardingOp.getSplitAxes ().getAxes ();
790- auto partialAxes = shardingOp.getPartialAxes ().value_or (ArrayRef<MeshAxis>());
791- // If splitAxes and partialAxes are empty, use "empty" constructor.
792- if (splitAxes.empty () && partialAxes.empty ()) {
761+ // If splitAxes are empty, use "empty" constructor.
762+ if (splitAxes.empty ()) {
793763 *this = MeshSharding (shardingOp.getMeshAttr ());
794764 return ;
795765 }
796- *this = get (shardingOp.getMeshAttr (), splitAxes, partialAxes,
797- shardingOp.getPartialType ().value_or (ReductionKind::Sum),
798- shardingOp.getStaticHaloSizes (),
799- shardingOp.getStaticShardedDimsOffsets (),
800- SmallVector<Value>(shardingOp.getDynamicHaloSizes ()),
801- SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets ()));
766+ *this =
767+ get (shardingOp.getMeshAttr (), splitAxes, shardingOp.getStaticHaloSizes (),
768+ shardingOp.getStaticShardedDimsOffsets (),
769+ SmallVector<Value>(shardingOp.getDynamicHaloSizes ()),
770+ SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets ()));
802771}
803772
804773MeshSharding MeshSharding::get (::mlir::FlatSymbolRefAttr mesh_,
805774 ArrayRef<MeshAxesAttr> split_axes_,
806- ArrayRef<MeshAxis> partial_axes_,
807- ReductionKind partial_type_,
808775 ArrayRef<int64_t > static_halo_sizes_,
809776 ArrayRef<int64_t > static_sharded_dims_offsets_,
810777 ArrayRef<Value> dynamic_halo_sizes_,
811778 ArrayRef<Value> dynamic_sharded_dims_offsets_) {
812779 MeshSharding res (mesh_);
813- if (split_axes_.empty () && partial_axes_. empty () ) {
780+ if (split_axes_.empty ()) {
814781 return res;
815782 }
816783
@@ -825,8 +792,6 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
825792 llvm::copy (src, dst.begin ());
826793 };
827794
828- clone (partial_axes_, res.partial_axes );
829- res.partial_type = partial_type_;
830795 clone (static_halo_sizes_, res.static_halo_sizes );
831796 clone (static_sharded_dims_offsets_, res.static_sharded_dims_offsets );
832797 clone (dynamic_halo_sizes_, res.dynamic_halo_sizes );
0 commit comments