Skip to content

Commit 432da09

Browse files
junwhanahn authored and Google-ML-Automation committed
Clean up the reshard API from the IFRT Proxy server
It has been six months since we switched from `Reshard` to `CopyArrays`. Per the compatibility contract, it is now safe to remove the Reshard emulation code on the proxy server.

PiperOrigin-RevId: 707182393
1 parent 60aeeca commit 432da09

File tree

6 files changed: +5 additions, -122 deletions (lines changed)

xla/python/ifrt_proxy/client/rpc_helper.cc

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,6 @@ RPC(CopyToHostBuffer, copy_to_host_buffer);
330330
RPC(IsArrayDeleted, is_array_deleted);
331331
RPC(DestructArray, destruct_array)
332332
RPC(CopyArrays, copy_arrays);
333-
RPC(Reshard, reshard);
334333
RPC(FullyReplicatedShard, fully_replicated_shard);
335334
RPC(DeleteArray, delete_array);
336335
RPC(Compile, compile);

xla/python/ifrt_proxy/client/rpc_helper.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,6 @@ class RpcHelper {
112112
std::unique_ptr<CopyToHostBufferRequest> req);
113113
ResponseFuture<CopyArraysResponse> CopyArrays(
114114
std::unique_ptr<CopyArraysRequest> req);
115-
ResponseFuture<ReshardResponse> Reshard(std::unique_ptr<ReshardRequest> req);
116115
ResponseFuture<FullyReplicatedShardResponse> FullyReplicatedShard(
117116
std::unique_ptr<FullyReplicatedShardRequest> req);
118117
ResponseFuture<IsArrayDeletedResponse> IsArrayDeleted(

xla/python/ifrt_proxy/common/ifrt_service.proto

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@ message IfrtRequest {
5656
disassemble_into_single_device_arrays_request = 7;
5757
DeleteArrayRequest delete_array_request = 9;
5858
CopyArraysRequest copy_arrays_request = 24;
59-
ReshardRequest reshard_request = 10 [deprecated = true];
6059
FullyReplicatedShardRequest fully_replicated_shard_request = 20;
6160
IsArrayDeletedRequest is_array_deleted_request = 11;
6261
DestructArrayRequest destruct_array_request = 12;
@@ -79,6 +78,8 @@ message IfrtRequest {
7978
GetDefaultDeviceAssignmentRequest get_default_device_assignment_request =
8079
19;
8180
}
81+
82+
reserved 10;
8283
}
8384

8485
message IfrtResponse {
@@ -103,7 +104,6 @@ message IfrtResponse {
103104
disassemble_into_single_device_arrays_response = 7;
104105
DeleteArrayResponse delete_array_response = 9;
105106
CopyArraysResponse copy_arrays_response = 24;
106-
ReshardResponse reshard_response = 10 [deprecated = true];
107107
FullyReplicatedShardResponse fully_replicated_shard_response = 20;
108108
IsArrayDeletedResponse is_array_deleted_response = 11;
109109
DestructArrayResponse destruct_array_response = 12;
@@ -127,6 +127,8 @@ message IfrtResponse {
127127
GetDefaultDeviceAssignmentResponse get_default_device_assignment_response =
128128
19;
129129
}
130+
131+
reserved 10;
130132
}
131133

132134
// Metadata of an IFRT Request.
@@ -323,15 +325,6 @@ message CopyArraysResponse {
323325
repeated fixed64 array_handles = 1;
324326
}
325327

326-
message ReshardRequest {
327-
fixed64 array_handle = 1;
328-
ShardingProto sharding = 2;
329-
proto.ArrayCopySemantics copy_semantics = 3;
330-
}
331-
message ReshardResponse {
332-
fixed64 array_handle = 1;
333-
}
334-
335328
message FullyReplicatedShardRequest {
336329
fixed64 array_handle = 1;
337330
proto.ArrayCopySemantics copy_semantics = 2;

xla/python/ifrt_proxy/common/versions.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ namespace protocol_version {
2626
inline constexpr int kClientMin = 3;
2727

2828
// The minimum protocol_version that the current server code understands.
29-
inline constexpr int kServerMin = 1;
29+
inline constexpr int kServerMin = 3;
3030

3131
enum {
3232
// Versions kAncient are named and are only referred to by their numbers. See

xla/python/ifrt_proxy/server/ifrt_backend.cc

Lines changed: 0 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -342,8 +342,6 @@ Future<BackendInterface::Response> IfrtBackend::ProcessInternal(
342342
return Future<Response>(HandleCheckValueReadyRequest(std::move(request)));
343343
case IfrtRequest::RequestCase::kCopyArraysRequest:
344344
return Future<Response>(HandleCopyArraysRequest(std::move(request)));
345-
case IfrtRequest::RequestCase::kReshardRequest:
346-
return Future<Response>(HandleReshardRequest(std::move(request)));
347345
case IfrtRequest::RequestCase::kFullyReplicatedShardRequest:
348346
return Future<Response>(
349347
HandleFullyReplicatedShardRequest(std::move(request)));
@@ -1029,44 +1027,6 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleCopyArraysRequest(
10291027
return ifrt_resp;
10301028
}
10311029

1032-
absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleReshardRequest(
1033-
std::unique_ptr<IfrtRequest> request) {
1034-
const auto& reshard_request = request->reshard_request();
1035-
TF_ASSIGN_OR_RETURN(auto array, GetArray(reshard_request.array_handle()));
1036-
TF_ASSIGN_OR_RETURN(
1037-
std::shared_ptr<const Sharding> sharding,
1038-
Sharding::FromProto(
1039-
absl::bind_front(&Client::LookupDevice, client_.get()),
1040-
reshard_request.sharding()));
1041-
TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto(
1042-
reshard_request.copy_semantics()));
1043-
1044-
// Emulate the old `Array::Reshard` behavior using `Client::CopyArrays`. No
1045-
// existing IFRT implementations before `Array::Reshard` was deleted actually
1046-
// supported resharding, so this should be safe.
1047-
if (!array->sharding().HasSamePartitioning(*sharding)) {
1048-
return absl::InvalidArgumentError(absl::StrCat(
1049-
"IFRT Proxy does not support resharding, but got ",
1050-
array->sharding().DebugString(), " as the original sharding and ",
1051-
sharding->DebugString(), " as the target sharding"));
1052-
}
1053-
TF_ASSIGN_OR_RETURN(
1054-
auto copied_arrays,
1055-
client_->CopyArrays(absl::MakeSpan(&array, 1), sharding->devices(),
1056-
sharding->memory_kind(), semantics));
1057-
1058-
uint64_t resharded_array_handle = handle_generator_.GenerateAtServer();
1059-
{
1060-
absl::MutexLock lock(&arrays_mutex_);
1061-
arrays_.insert({resharded_array_handle, std::move(copied_arrays[0])});
1062-
}
1063-
1064-
auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id());
1065-
ifrt_resp->mutable_reshard_response()->set_array_handle(
1066-
resharded_array_handle);
1067-
return ifrt_resp;
1068-
}
1069-
10701030
absl::StatusOr<BackendInterface::Response>
10711031
IfrtBackend::HandleFullyReplicatedShardRequest(
10721032
std::unique_ptr<IfrtRequest> request) {

xla/python/ifrt_proxy/server/ifrt_backend_test.cc

Lines changed: 0 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -970,74 +970,6 @@ TEST_P(IfrtBackendHandlerTest, CopyArrays) {
970970
SizeIs(copied_arrays.size()));
971971
}
972972

973-
TEST_P(IfrtBackendHandlerTest, ReshardSuccess) {
974-
auto src_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
975-
TF_ASSERT_OK_AND_ASSIGN(auto* device,
976-
mock_client_->LookupDevice(DeviceId(0)));
977-
auto src_sharding = SingleDeviceSharding::Create(device, MemoryKind());
978-
ON_CALL(*src_mock_array, sharding()).WillByDefault(ReturnRef(*src_sharding));
979-
TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle,
980-
MakeTestArray(std::move(src_mock_array)));
981-
982-
auto copied_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
983-
EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _))
984-
.WillOnce(Return(std::vector<tsl::RCReference<xla::ifrt::Array>>(
985-
{copied_mock_array})));
986-
987-
auto ifrt_request = NewIfrtRequest(NewOpId());
988-
auto* reshard_request = ifrt_request->mutable_reshard_request();
989-
reshard_request->set_array_handle(src_array_handle);
990-
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
991-
TF_ASSERT_OK_AND_ASSIGN(auto* new_device,
992-
mock_client_->LookupDevice(DeviceId(1)));
993-
TF_ASSERT_OK_AND_ASSIGN(
994-
*ifrt_request->mutable_reshard_request()->mutable_sharding(),
995-
SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto());
996-
997-
TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request)));
998-
999-
EXPECT_THAT(tsl::StatusFromProto(response->response_metadata().status()),
1000-
IsOk());
1001-
EXPECT_NE(response->reshard_response().array_handle(), 0);
1002-
}
1003-
1004-
TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) {
1005-
auto mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
1006-
TF_ASSERT_OK_AND_ASSIGN(auto* device,
1007-
mock_client_->LookupDevice(DeviceId(1)));
1008-
auto sharding = SingleDeviceSharding::Create(device, MemoryKind());
1009-
ON_CALL(*mock_array, sharding()).WillByDefault(ReturnRef(*sharding));
1010-
TF_ASSERT_OK_AND_ASSIGN(auto array_handle,
1011-
MakeTestArray(std::move(mock_array)));
1012-
1013-
EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _))
1014-
.WillOnce(Return(absl::UnknownError("injected error")));
1015-
1016-
auto ifrt_request = NewIfrtRequest(NewOpId());
1017-
auto* reshard_request = ifrt_request->mutable_reshard_request();
1018-
reshard_request->set_array_handle(array_handle);
1019-
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
1020-
TF_ASSERT_OK_AND_ASSIGN(auto* new_device,
1021-
mock_client_->LookupDevice(DeviceId(1)));
1022-
TF_ASSERT_OK_AND_ASSIGN(
1023-
*ifrt_request->mutable_reshard_request()->mutable_sharding(),
1024-
SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto());
1025-
1026-
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
1027-
StatusIs(absl::StatusCode::kUnknown, StrEq("injected error")));
1028-
}
1029-
1030-
TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) {
1031-
auto ifrt_request = NewIfrtRequest(NewOpId());
1032-
auto* reshard_request = ifrt_request->mutable_reshard_request();
1033-
reshard_request->set_array_handle(0);
1034-
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
1035-
reshard_request->mutable_sharding();
1036-
1037-
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
1038-
StatusIs(absl::StatusCode::kNotFound));
1039-
}
1040-
1041973
TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) {
1042974
auto fully_replicated_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
1043975
auto resultant_array = tsl::MakeRef<xla::ifrt::MockArray>();

0 commit comments

Comments
 (0)