@@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
454454 ArrayRef<MeshAxesAttr> split_axes,
455455 ArrayRef<MeshAxis> partial_axes,
456456 mesh::ReductionKind partial_type,
457- ArrayRef<int64_t > static_halo_sizes ,
458- ArrayRef<int64_t > static_sharded_dims_offsets ) {
457+ ArrayRef<int64_t > static_halos ,
458+ ArrayRef<int64_t > static_offsets ) {
459459 return build (
460460 b, odsState, mesh, MeshAxesArrayAttr::get (b.getContext (), split_axes),
461461 ::mlir::DenseI16ArrayAttr::get (b.getContext(), partial_axes),
462462 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
463- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
464- ::mlir::DenseI64ArrayAttr::get (b.getContext(),
465- static_sharded_dims_offsets),
466- {});
463+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
464+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
467465}
468466
469467void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
475473 {}, {}, {}, {});
476474}
477475
476+ void ShardingOp::build (::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
477+ llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
478+ ArrayRef<int64_t > static_halos,
479+ ArrayRef<int64_t > static_offsets) {
480+ return build (
481+ b, odsState, FlatSymbolRefAttr::get (b.getContext (), mesh),
482+ MeshAxesArrayAttr::get (b.getContext (), split_axes), {},
483+ ::mlir::mesh::ReductionKindAttr::get (b.getContext(), ReductionKind::Sum),
484+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
485+ ::mlir::DenseI64ArrayAttr::get (b.getContext(), static_offsets), {});
486+ }
487+
478488void ShardingOp::build (
479489 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480490 FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
0 commit comments