Skip to content

Commit b6c123d

Browse files
akuegelGoogle-ML-Automation
authored andcommitted
Adapt Stablehlo to recent upstream Tosa changes.
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
1 parent 469329e commit b6c123d

File tree

18 files changed

+750
-127
lines changed

18 files changed

+750
-127
lines changed

third_party/stablehlo/temporary.patch

Lines changed: 341 additions & 84 deletions
Large diffs are not rendered by default.

xla/backends/gpu/runtime/collective_permute_thunk.cc

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ absl::Status CollectivePermuteStartThunk::Initialize(
190190
params.executor->CreateEvent());
191191
receiver_barrier_events_.emplace(current_id, std::move(receiver_event));
192192
}
193+
if (sender_barrier_events_.find(current_id) ==
194+
sender_barrier_events_.end()) {
195+
TF_ASSIGN_OR_RETURN(auto sender_event, params.executor->CreateEvent());
196+
sender_barrier_events_.emplace(current_id, std::move(sender_event));
197+
}
193198
}
194199
TF_ASSIGN_OR_RETURN(
195200
std::vector<DeviceBufferPair> device_buffers,
@@ -273,7 +278,7 @@ absl::Status CollectivePermuteStartThunk::RunCollective(
273278
config().replica_groups, config().group_mode));
274279

275280
auto rendezvous_name = absl::StrFormat(
276-
"rendezvous of collective-permute; run_id=%d; op id:%d; "
281+
"rendezvous before calling collective-permute; run_id=%d; op id:%d; "
277282
"num_local_participants:%d",
278283
params.collective_params->run_id.ToInt(), config_.config.op_id,
279284
num_local_participants);
@@ -293,9 +298,48 @@ absl::Status CollectivePermuteStartThunk::RunCollective(
293298
}
294299
}
295300

296-
return ::xla::gpu::RunCollectivePermute(
301+
auto status = ::xla::gpu::RunCollectivePermute(
297302
collectives, source_target, device_buffers, stream, comm_handle.comm,
298303
device_string, current_id, use_memcpy, recv_ptr_map_);
304+
305+
if (use_memcpy) {
306+
std::optional<int64_t> source_id = source_target.source;
307+
std::optional<int64_t> target_id = source_target.target;
308+
// After the memcpy p2p is dispatched, the receiver needs to
309+
// wait for the sender's event before proceeding to ensure
310+
// data has been copied.
311+
if (target_id) {
312+
absl::MutexLock lock(&barrier_mutex_);
313+
auto sender_event = sender_barrier_events_.find(current_id);
314+
TF_RETURN_IF_ERROR(stream.RecordEvent(sender_event->second.get()));
315+
}
316+
TF_ASSIGN_OR_RETURN(
317+
size_t num_local_participants,
318+
GetNumLocalParticipants(*params.collective_params,
319+
config().replica_groups, config().group_mode));
320+
321+
auto rendezvous_name = absl::StrFormat(
322+
"rendezvous after calling collective-permute; run_id=%d; op id:%d; "
323+
"num_local_participants:%d",
324+
params.collective_params->run_id.ToInt(), config_.config.op_id,
325+
num_local_participants);
326+
auto rendezvous_key = CallRendezvousKey{params.collective_params->run_id};
327+
328+
// Perform a rendezvous to make sure all senders have their events
329+
// recorded.
330+
Rendezvous(rendezvous_name, rendezvous_key, num_local_participants,
331+
/*warn_stuck_timeout=*/absl::Seconds(20),
332+
/*terminate_timeout=*/absl::Seconds(40));
333+
334+
// For receiving side, wait for the recorded event from the sending side.
335+
if (source_id) {
336+
absl::MutexLock lock(&barrier_mutex_);
337+
auto sender_event = sender_barrier_events_.find(*source_id);
338+
TF_RETURN_IF_ERROR(stream.WaitFor(sender_event->second.get()));
339+
}
340+
}
341+
342+
return status;
299343
}
300344

301345
absl::Status RunCollectivePermute(

xla/backends/gpu/runtime/collective_permute_thunk.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ class CollectivePermuteStartThunk : public CollectiveThunk {
122122
absl::Mutex barrier_mutex_;
123123
absl::flat_hash_map<int64_t, std::unique_ptr<se::Event>>
124124
receiver_barrier_events_;
125+
absl::flat_hash_map<int64_t, std::unique_ptr<se::Event>>
126+
sender_barrier_events_;
127+
125128
bool p2p_memcpy_enabled_ = false;
126129
int64_t device_count_;
127130
};

xla/python/ifrt_proxy/client/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ cc_library(
237237
":global_flags",
238238
":rpc_helper",
239239
"//xla:status_macros",
240+
"//xla/pjrt:pjrt_layout",
240241
"//xla/python/ifrt",
241242
"//xla/python/ifrt:client_impl_util",
242243
"//xla/python/ifrt:sharding_serdes",
@@ -258,6 +259,7 @@ cc_library(
258259
"@com_google_absl//absl/strings",
259260
"@com_google_absl//absl/strings:cord",
260261
"@com_google_absl//absl/strings:str_format",
262+
"@com_google_absl//absl/synchronization",
261263
"@com_google_absl//absl/types:span",
262264
"@llvm-project//llvm:Support",
263265
"@tsl//tsl/profiler/lib:traceme",
@@ -400,6 +402,7 @@ cc_library(
400402
"@com_google_absl//absl/base:nullability",
401403
"@com_google_absl//absl/cleanup",
402404
"@com_google_absl//absl/container:flat_hash_map",
405+
"@com_google_absl//absl/container:flat_hash_set",
403406
"@com_google_absl//absl/container:node_hash_set",
404407
"@com_google_absl//absl/log",
405408
"@com_google_absl//absl/log:check",

xla/python/ifrt_proxy/client/array.cc

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "absl/strings/str_join.h"
3737
#include "absl/strings/string_view.h"
3838
#include "absl/strings/substitute.h"
39+
#include "absl/synchronization/mutex.h"
3940
#include "absl/types/span.h"
4041
#include "llvm/Support/Casting.h"
4142
#include "xla/python/ifrt/array.h"
@@ -393,6 +394,13 @@ Future<> Array::GetReadyFuture() const {
393394
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
394395
"IfrtProxyEntrypointArrayGetReadyFuture");
395396

397+
{
398+
absl::MutexLock lock(&mu_);
399+
if (deleted_ == DeletionState::kDeleted) {
400+
return Future<>(absl::InvalidArgumentError("Already deleted array."));
401+
}
402+
}
403+
396404
auto req = std::make_unique<CheckValueReadyRequest>();
397405
req->add_value_handles(handle_.handle);
398406

@@ -405,6 +413,10 @@ Future<> Array::GetReadyFuture() const {
405413
}
406414

407415
Future<> Array::Delete() {
416+
{
417+
absl::MutexLock lock(&mu_);
418+
deleted_ = DeletionState::kDeleted;
419+
}
408420
if (rpc_helper_->version().protocol_version() >= 5) {
409421
rpc_helper_->Batch(RpcHelper::kDeleteArray, handle_);
410422
return Future<>(absl::OkStatus());
@@ -429,6 +441,15 @@ Future<> Array::Delete() {
429441
bool Array::IsDeleted() const {
430442
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
431443
"IfrtProxyEntrypointIsDeleted");
444+
{
445+
absl::MutexLock lock(&mu_);
446+
if (deleted_ == DeletionState::kDeleted) {
447+
return true;
448+
}
449+
if (deleted_ == DeletionState::kAlive) {
450+
return false;
451+
}
452+
}
432453
if (GetGlobalClientFlags()->array_is_deleted_hack) {
433454
return false;
434455
}
@@ -438,6 +459,12 @@ bool Array::IsDeleted() const {
438459
absl::StatusOr<std::shared_ptr<IsArrayDeletedResponse>> response =
439460
rpc_helper_->IsArrayDeleted(std::move(req)).Await();
440461
if (response.ok()) {
462+
absl::MutexLock lock(&mu_);
463+
if ((*response)->deleted()) {
464+
deleted_ = DeletionState::kDeleted;
465+
} else {
466+
deleted_ = DeletionState::kAlive;
467+
}
441468
return (*response)->deleted();
442469
} else {
443470
LOG(ERROR) << "Internal error from proxy server during Array::IsDeleted(): "
@@ -486,7 +513,9 @@ Array::AssembleArrayFromSingleDeviceArrays(
486513
"not a xla::ifrt::proxy::Array.",
487514
rcref.get()));
488515
}
489-
req->add_single_device_array_handles(array->handle_.handle);
516+
TF_ASSIGN_OR_RETURN(ArrayHandle handle,
517+
array->GetHandle(array_copy_semantics));
518+
req->add_single_device_array_handles(handle.handle);
490519
}
491520

492521
ArrayHandle result_handle;
@@ -569,8 +598,8 @@ Array::RemapArrays(xla::ifrt::Client* client,
569598
*arrays[i]->sharding().devices(),
570599
arrays[i]->sharding().memory_kind()));
571600
}
572-
573-
req->add_array_handles(array->handle_.handle);
601+
TF_ASSIGN_OR_RETURN(ArrayHandle handle, array->GetHandle(semantics));
602+
req->add_array_handles(handle.handle);
574603
}
575604

576605
std::vector<ArrayHandle> result_handles;
@@ -617,7 +646,8 @@ Array::DisassembleIntoSingleDeviceArrays(
617646
"version < 8");
618647
}
619648
auto req = std::make_unique<DisassembleIntoSingleDeviceArraysRequest>();
620-
req->set_array_handle(handle_.handle);
649+
TF_ASSIGN_OR_RETURN(ArrayHandle handle, GetHandle(array_copy_semantics));
650+
req->set_array_handle(handle.handle);
621651
req->set_copy_semantics(ToArrayCopySemanticsProto(array_copy_semantics));
622652
req->set_single_device_shard_semantics(
623653
ToSingleDeviceShardSemanticsProto(single_device_shard_semantics));
@@ -665,7 +695,8 @@ absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Array::FullyReplicatedShard(
665695
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
666696
"IfrtProxyEntrypointFullyReplicatedShard");
667697
auto req = std::make_unique<FullyReplicatedShardRequest>();
668-
req->set_array_handle(handle_.handle);
698+
TF_ASSIGN_OR_RETURN(ArrayHandle handle, GetHandle(semantics));
699+
req->set_array_handle(handle.handle);
669700
req->set_copy_semantics(ToArrayCopySemanticsProto(semantics));
670701

671702
ArrayHandle result_handle;
@@ -706,7 +737,11 @@ Future<> Array::CopyToStringHostBuffer(
706737
"String arrays are not supported in ifrt-proxy version < 9"));
707738
}
708739
auto req = std::make_unique<CopyToHostBufferRequest>();
709-
req->set_array_handle(handle_.handle);
740+
absl::StatusOr<ArrayHandle> handle = GetHandle(semantics);
741+
if (!handle.ok()) {
742+
return Future<>(handle.status());
743+
}
744+
req->set_array_handle(handle->handle);
710745
if (byte_strides.has_value()) {
711746
return Future<>(absl::InvalidArgumentError(
712747
"Byte strides are not supported for string arrays."));
@@ -768,7 +803,11 @@ Future<> Array::CopyToHostBuffer(
768803
}
769804

770805
auto req = std::make_unique<CopyToHostBufferRequest>();
771-
req->set_array_handle(handle_.handle);
806+
absl::StatusOr<ArrayHandle> handle = GetHandle(semantics);
807+
if (!handle.ok()) {
808+
return Future<>(handle.status());
809+
}
810+
req->set_array_handle(handle->handle);
772811
if (byte_strides.has_value()) {
773812
*req->mutable_byte_strides() = ToByteStridesProto(*byte_strides);
774813
}
@@ -829,8 +868,23 @@ Future<> Array::CopyToHostBuffer(
829868
xla::ifrt::Client* Array::client() const { return client_; }
830869

831870
std::string Array::DebugString() const {
832-
return absl::Substitute("proxy::Array, this=$0, handle=$1", this,
833-
handle_.handle);
871+
std::string is_deleted;
872+
{
873+
absl::MutexLock l(&mu_);
874+
switch (deleted_) {
875+
case DeletionState::kUnknown:
876+
is_deleted = "unknown";
877+
break;
878+
case DeletionState::kDeleted:
879+
is_deleted = "true";
880+
break;
881+
case DeletionState::kAlive:
882+
is_deleted = "false";
883+
break;
884+
}
885+
}
886+
return absl::Substitute("proxy::Array, this=$0, handle=$1, deleted=$2", this,
887+
handle_.handle, is_deleted);
834888
}
835889

836890
} // namespace proxy

xla/python/ifrt_proxy/client/array.h

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@
2626
#include <vector>
2727

2828
#include "absl/base/attributes.h"
29+
#include "absl/base/thread_annotations.h"
2930
#include "absl/status/status.h"
3031
#include "absl/status/statusor.h"
32+
#include "absl/synchronization/mutex.h"
3133
#include "absl/types/span.h"
3234
#include "llvm/Support/ExtensibleRTTI.h"
35+
#include "xla/pjrt/pjrt_layout.h"
3336
#include "xla/python/ifrt/array.h"
3437
#include "xla/python/ifrt/client.h"
3538
#include "xla/python/ifrt/dtype.h"
@@ -106,7 +109,31 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
106109

107110
~Array() override { Destruct(rpc_helper_.get(), handle_); }
108111

109-
ArrayHandle handle() const { return handle_; }
112+
absl::StatusOr<ArrayHandle> GetHandle(ArrayCopySemantics semantics) {
113+
absl::MutexLock l(&mu_);
114+
if (deleted_ == DeletionState::kDeleted) {
115+
return absl::InvalidArgumentError("Array already deleted.");
116+
}
117+
if (semantics == ArrayCopySemantics::kDonateInput) {
118+
deleted_ = DeletionState::kDeleted;
119+
}
120+
return handle_;
121+
}
122+
123+
// Fetches the ArrayHandle when the ArrayCopySemantics (i.e., whether the
124+
// array is meant to be donated or copied) is not known.
125+
//
126+
// Calling this function may cause `IsDelete()` calls to result in a
127+
// synchronous RPC to the proxy-server. To avoid such performance overhead,
128+
// prefer using `GetHandle(semantics)` whenever the semantics are known.
129+
absl::StatusOr<ArrayHandle> GetHandleUnknownIfBeingDonated() {
130+
absl::MutexLock l(&mu_);
131+
if (deleted_ == DeletionState::kDeleted) {
132+
return absl::InvalidArgumentError("Array already deleted.");
133+
}
134+
deleted_ = DeletionState::kUnknown;
135+
return handle_;
136+
}
110137

111138
xla::ifrt::Client* client() const override;
112139
Future<> GetReadyFuture() const override;
@@ -158,7 +185,17 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
158185
const DType dtype_;
159186
const Shape shape_;
160187
const std::shared_ptr<const Sharding> sharding_;
161-
const ArrayHandle handle_;
188+
189+
const ArrayHandle handle_
190+
ABSL_DEPRECATED("Use GetHandle() function instead.");
191+
192+
mutable absl::Mutex mu_;
193+
enum class DeletionState {
194+
kUnknown, // Need to ask the proxy-server whether the array is deleted.
195+
kDeleted, // IsDeleted() will return true.
196+
kAlive // IsDeleted() will return false.
197+
};
198+
mutable DeletionState deleted_ ABSL_GUARDED_BY(mu_) = DeletionState::kAlive;
162199
};
163200

164201
} // namespace proxy

xla/python/ifrt_proxy/client/client.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ Client::CopyArrays(absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
277277
for (const auto& array : arrays) {
278278
if (auto* proxy_array =
279279
llvm::dyn_cast<xla::ifrt::proxy::Array>(array.get())) {
280-
req->add_array_handles(proxy_array->handle().handle);
280+
TF_ASSIGN_OR_RETURN(ArrayHandle handle,
281+
proxy_array->GetHandle(semantics));
282+
req->add_array_handles(handle.handle);
281283
} else {
282284
return absl::InvalidArgumentError(
283285
"CopyArrays only supports arrays created via IFRT Proxy client");
@@ -351,7 +353,13 @@ xla::ifrt::Future<> Client::GetReadyFuture(
351353
// type, but this may be extended later to other types such as Tuples.
352354
if (auto proxy_array =
353355
llvm::dyn_cast<xla::ifrt::proxy::Array>(value.get())) {
354-
req->add_value_handles(proxy_array->handle().handle);
356+
absl::StatusOr<ArrayHandle> handle =
357+
proxy_array->GetHandle(ArrayCopySemantics::kAlwaysCopy);
358+
if (!handle.ok()) {
359+
futures.push_back(Future<>(handle.status()));
360+
} else {
361+
req->add_value_handles(handle->handle);
362+
}
355363
} else {
356364
futures.push_back(value->GetReadyFuture());
357365
}

0 commit comments

Comments
 (0)