@@ -479,37 +479,23 @@ void MeshShapeOp::getAsmResultNames(
479
479
void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480
480
FlatSymbolRefAttr mesh,
481
481
ArrayRef<MeshAxesAttr> split_axes,
482
- ArrayRef<MeshAxis> partial_axes,
483
- mesh::ReductionKind partial_type,
484
482
ArrayRef<int64_t > static_halos,
485
483
ArrayRef<int64_t > static_offsets) {
486
484
return build (
487
485
b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes),
488
- ::mlir::DenseI16ArrayAttr::get (b.getContext(), partial_axes),
489
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
490
486
::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos), {},
491
487
::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
492
488
}
493
489
494
- void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
495
- FlatSymbolRefAttr mesh,
496
- ArrayRef<MeshAxesAttr> split_axes) {
497
- return build (
498
- b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
499
- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
500
- {}, {}, {}, {});
501
- }
502
-
503
490
void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
504
491
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
505
492
ArrayRef<int64_t > static_halos,
506
493
ArrayRef<int64_t > static_offsets) {
507
- return build (
508
- b, odsState, FlatSymbolRefAttr::get (b.getContext (), mesh),
509
- MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
510
- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
511
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
512
- ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
494
+ return build (b, odsState, FlatSymbolRefAttr::get (b.getContext (), mesh),
495
+ MeshAxesArrayAttr::get (b.getContext (), split_axes),
496
+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_halos), {},
497
+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets),
498
+ {});
513
499
}
514
500
515
501
void ShardingOp::build (
@@ -522,8 +508,7 @@ void ShardingOp::build(
522
508
dispatchIndexOpFoldResults (halo_sizes, dynamicHalos, staticHalos);
523
509
dispatchIndexOpFoldResults (sharded_dims_offsets, dynamicDims, staticDims);
524
510
return build (
525
- b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
526
- ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
511
+ b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes),
527
512
::mlir::DenseI64ArrayAttr::get (b.getContext(), staticHalos), dynamicHalos,
528
513
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
529
514
}
@@ -533,11 +518,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
533
518
534
519
build (b, odsState, ShardingType::get (b.getContext ()), from.getMeshAttr (),
535
520
MeshAxesArrayAttr::get (b.getContext (), from.getSplitAxes ()),
536
- from.getPartialAxes ().empty ()
537
- ? DenseI16ArrayAttr ()
538
- : b.getDenseI16ArrayAttr (from.getPartialAxes ()),
539
- ::mlir::mesh::ReductionKindAttr::get (b.getContext(),
540
- from.getPartialType()),
541
521
from.getStaticShardedDimsOffsets ().empty ()
542
522
? DenseI64ArrayAttr ()
543
523
: b.getDenseI64ArrayAttr (from.getStaticShardedDimsOffsets ()),
@@ -566,9 +546,6 @@ LogicalResult ShardingOp::verify() {
566
546
if (failed (checkMeshAxis (subAxesArray)))
567
547
return failure ();
568
548
}
569
- if (getPartialAxes ().has_value () &&
570
- failed (checkMeshAxis (getPartialAxes ().value ())))
571
- return failure ();
572
549
573
550
if (!getStaticHaloSizes ().empty () && !getStaticShardedDimsOffsets ().empty ()) {
574
551
return emitOpError (" halo sizes and shard offsets are mutually exclusive" );
@@ -710,17 +687,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
710
687
// MeshSharding
711
688
// ===----------------------------------------------------------------------===//
712
689
713
- bool MeshSharding::equalSplitAndPartialAxes (const MeshSharding &rhs) const {
690
+ bool MeshSharding::equalSplitAxes (const MeshSharding &rhs) const {
714
691
if (getMesh () != rhs.getMesh ()) {
715
692
return false ;
716
693
}
717
694
718
- if (getPartialAxes ().size () != rhs.getPartialAxes ().size () ||
719
- (!getPartialAxes ().empty () && getPartialType () != rhs.getPartialType ()) ||
720
- !llvm::equal (getPartialAxes (), rhs.getPartialAxes ())) {
721
- return false ;
722
- }
723
-
724
695
auto minSize = std::min (getSplitAxes ().size (), rhs.getSplitAxes ().size ());
725
696
if (!llvm::equal (llvm::make_range (getSplitAxes ().begin (),
726
697
getSplitAxes ().begin () + minSize),
@@ -768,13 +739,13 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
768
739
}
769
740
770
741
bool MeshSharding::operator ==(Value rhs) const {
771
- return equalSplitAndPartialAxes (rhs) && equalHaloAndShardSizes (rhs);
742
+ return equalSplitAxes (rhs) && equalHaloAndShardSizes (rhs);
772
743
}
773
744
774
745
bool MeshSharding::operator !=(Value rhs) const { return !(*this == rhs); }
775
746
776
747
bool MeshSharding::operator ==(const MeshSharding &rhs) const {
777
- return equalSplitAndPartialAxes (rhs) && equalHaloAndShardSizes (rhs);
748
+ return equalSplitAxes (rhs) && equalHaloAndShardSizes (rhs);
778
749
}
779
750
780
751
bool MeshSharding::operator !=(const MeshSharding &rhs) const {
@@ -787,30 +758,26 @@ MeshSharding::MeshSharding(Value rhs) {
787
758
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp ());
788
759
assert (shardingOp && " expected sharding op" );
789
760
auto splitAxes = shardingOp.getSplitAxes ().getAxes ();
790
- auto partialAxes = shardingOp.getPartialAxes ().value_or (ArrayRef<MeshAxis>());
791
- // If splitAxes and partialAxes are empty, use "empty" constructor.
792
- if (splitAxes.empty () && partialAxes.empty ()) {
761
+ // If splitAxes are empty, use "empty" constructor.
762
+ if (splitAxes.empty ()) {
793
763
*this = MeshSharding (shardingOp.getMeshAttr ());
794
764
return ;
795
765
}
796
- *this = get (shardingOp.getMeshAttr (), splitAxes, partialAxes,
797
- shardingOp.getPartialType ().value_or (ReductionKind::Sum),
798
- shardingOp.getStaticHaloSizes (),
799
- shardingOp.getStaticShardedDimsOffsets (),
800
- SmallVector<Value>(shardingOp.getDynamicHaloSizes ()),
801
- SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets ()));
766
+ *this =
767
+ get (shardingOp.getMeshAttr (), splitAxes, shardingOp.getStaticHaloSizes (),
768
+ shardingOp.getStaticShardedDimsOffsets (),
769
+ SmallVector<Value>(shardingOp.getDynamicHaloSizes ()),
770
+ SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets ()));
802
771
}
803
772
804
773
MeshSharding MeshSharding::get (::mlir::FlatSymbolRefAttr mesh_,
805
774
ArrayRef<MeshAxesAttr> split_axes_,
806
- ArrayRef<MeshAxis> partial_axes_,
807
- ReductionKind partial_type_,
808
775
ArrayRef<int64_t > static_halo_sizes_,
809
776
ArrayRef<int64_t > static_sharded_dims_offsets_,
810
777
ArrayRef<Value> dynamic_halo_sizes_,
811
778
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
812
779
MeshSharding res (mesh_);
813
- if (split_axes_.empty () && partial_axes_. empty () ) {
780
+ if (split_axes_.empty ()) {
814
781
return res;
815
782
}
816
783
@@ -825,8 +792,6 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
825
792
llvm::copy (src, dst.begin ());
826
793
};
827
794
828
- clone (partial_axes_, res.partial_axes );
829
- res.partial_type = partial_type_;
830
795
clone (static_halo_sizes_, res.static_halo_sizes );
831
796
clone (static_sharded_dims_offsets_, res.static_sharded_dims_offsets );
832
797
clone (dynamic_halo_sizes_, res.dynamic_halo_sizes );
0 commit comments