Skip to content

Commit 4b4165b

Browse files
junwhanahnGoogle-ML-Automation
authored andcommitted
Return std::nullopt from xla::ifrt::LoadedExecutable::devices() for portable executables
Portable executables are not bound to any specific device list, so it makes sense to be explicit about it rather than letting each implementation return an unspecified value. PiperOrigin-RevId: 837188096
1 parent ca8d1ec commit 4b4165b

File tree

12 files changed

+40
-23
lines changed

12 files changed

+40
-23
lines changed

xla/backends/cpu/nanort/ifrt_client.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,9 @@ class NanoExecutable final
10091009
return client_->addressable_devices();
10101010
}
10111011

1012-
const ifrt::DeviceListRef& devices() const override { return devices_; }
1012+
std::optional<ifrt::DeviceListRef> devices() const override {
1013+
return devices_;
1014+
}
10131015

10141016
static char ID; // NOLINT
10151017

xla/backends/cpu/nanort/ifrt_client_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ int main(int argc, char** argv) {
337337
"*LoadedExecutableImplTest.GetHloModules*:"
338338
"*LoadedExecutableImplTest.ProgramText*:"
339339
"*LoadedExecutableImplTest.Analysis*:"
340+
// NanoRT does not support portable execution.
341+
"*LoadedExecutableImplTest.CompileAndExecutePortable*:"
340342
// Serialization is not implemented.
341343
"*SerializeAndLoad*";
342344
xla::ifrt::test_util::SetTestFilterIfNotUserSpecified(kFilter);

xla/python/ifrt/executable.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,9 @@ class LoadedExecutable
279279
std::optional<DeviceListRef> devices) = 0;
280280

281281
// Returns the list of devices where the executable has been compiled and
282-
// loaded onto.
283-
virtual const DeviceListRef& devices() const = 0;
282+
// loaded onto. Returns `std::nullopt` if the executable is not bound to a
283+
// particular device list, e.g., portable executables.
284+
virtual std::optional<DeviceListRef> devices() const = 0;
284285

285286
// The following APIs are taken from xla::PjRtLoadedExecutable for fast
286287
// prototyping.

xla/python/ifrt/mock.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,6 @@ class MockExecutable : public llvm::RTTIExtends<MockExecutable, Executable> {
317317
class MockLoadedExecutable
318318
: public llvm::RTTIExtends<MockLoadedExecutable, LoadedExecutable> {
319319
public:
320-
MockLoadedExecutable() {
321-
static absl::NoDestructor<DeviceListRef> kEmptyDeviceList(
322-
BasicDeviceList::Create({}));
323-
ON_CALL(*this, devices())
324-
.WillByDefault(testing::ReturnRef(*kEmptyDeviceList));
325-
}
326-
327320
MOCK_METHOD(Client*, client, (), (const, final));
328321
MOCK_METHOD(absl::string_view, name, (), (const, final));
329322
MOCK_METHOD(absl::StatusOr<std::optional<std::string>>, Fingerprint, (),
@@ -363,7 +356,7 @@ class MockLoadedExecutable
363356
(final));
364357
MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
365358
(const, final));
366-
MOCK_METHOD(const DeviceListRef&, devices, (), (const, final));
359+
MOCK_METHOD(std::optional<DeviceListRef>, devices, (), (const, final));
367360

368361
static char ID; // NOLINT
369362
};

xla/python/ifrt_proxy/client/compiler.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,10 @@ absl::StatusOr<xla::ifrt::LoadedExecutableRef> Compiler::CompileAndLoad(
177177
devices.push_back(device);
178178
}
179179
}
180-
TF_ASSIGN_OR_RETURN(DeviceListRef device_list,
181-
client_->MakeDeviceList(devices));
180+
std::optional<DeviceListRef> device_list;
181+
if (!devices.empty()) {
182+
TF_ASSIGN_OR_RETURN(device_list, client_->MakeDeviceList(devices));
183+
}
182184

183185
return std::make_unique<LoadedExecutable>(
184186
client_, rpc_helper_, response->loaded_executable_handle(),

xla/python/ifrt_proxy/client/executable.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ class LoadedExecutable::OutputSpecCache {
337337

338338
LoadedExecutable::LoadedExecutable(
339339
xla::ifrt::Client* client, std::shared_ptr<RpcHelper> rpc_helper,
340-
uint64_t handle, std::string name, int num_devices, DeviceListRef devices,
340+
uint64_t handle, std::string name, int num_devices,
341+
std::optional<DeviceListRef> devices,
341342
std::vector<xla::ifrt::Device*> addressable_devices,
342343
absl::StatusOr<std::optional<std::string>> fingerprint,
343344
tsl::Future<> ready_future,
@@ -843,7 +844,9 @@ LoadedExecutable::Execute(absl::Span<xla::ifrt::ArrayRef> args,
843844
return result;
844845
}
845846

846-
const DeviceListRef& LoadedExecutable::devices() const { return devices_; }
847+
std::optional<DeviceListRef> LoadedExecutable::devices() const {
848+
return devices_;
849+
}
847850

848851
absl::Span<xla::ifrt::Device* const> LoadedExecutable::addressable_devices()
849852
const {

xla/python/ifrt_proxy/client/executable.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ class LoadedExecutable final
5959
public:
6060
LoadedExecutable(xla::ifrt::Client* client,
6161
std::shared_ptr<RpcHelper> rpc_helper, uint64_t handle,
62-
std::string name, int num_devices, DeviceListRef devices,
62+
std::string name, int num_devices,
63+
std::optional<DeviceListRef> devices,
6364
std::vector<xla::ifrt::Device*> addressable_devices,
6465
absl::StatusOr<std::optional<std::string>> fingerprint,
6566
tsl::Future<> ready_future,
@@ -111,7 +112,7 @@ class LoadedExecutable final
111112
absl::Span<xla::ifrt::ArrayRef> args, const ExecuteOptions& options,
112113
std::optional<xla::ifrt::DeviceListRef> devices) override;
113114

114-
const DeviceListRef& devices() const override;
115+
std::optional<DeviceListRef> devices() const override;
115116
absl::Span<xla::ifrt::Device* const> addressable_devices() const override;
116117

117118
static char ID; // NOLINT
@@ -148,7 +149,7 @@ class LoadedExecutable final
148149
const uint64_t handle_;
149150
const std::string name_;
150151
const int num_devices_;
151-
const DeviceListRef devices_;
152+
const std::optional<DeviceListRef> devices_;
152153
const std::vector<xla::ifrt::Device*> addressable_devices_;
153154
const absl::StatusOr<std::optional<std::string>> fingerprint_;
154155
const tsl::Future<> ready_future_;

xla/python/ifrt_proxy/server/ifrt_backend.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,8 +1456,12 @@ tsl::Future<BackendInterface::Response> IfrtBackend::HandleCompileRequest(
14561456
for (const auto* device : executable->addressable_devices()) {
14571457
compile_resp->add_addressable_device_ids(device->Id().value());
14581458
}
1459-
for (const auto* device : executable->devices()->devices()) {
1460-
compile_resp->add_device_ids(device->Id().value());
1459+
if (std::optional<xla::ifrt::DeviceListRef> device_list =
1460+
executable->devices();
1461+
device_list.has_value()) {
1462+
for (const auto* device : (*device_list)->devices()) {
1463+
compile_resp->add_device_ids(device->Id().value());
1464+
}
14611465
}
14621466
// TODO(b/282757875): Consider making fingerprint calculation asynchronous
14631467
// if it is expected to take long.

xla/python/ifrt_proxy/server/ifrt_backend_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,8 @@ TEST_P(IfrtBackendHandlerTest, CompileSuccess) {
11571157
auto executable = std::make_unique<MockLoadedExecutable>();
11581158
EXPECT_CALL(*executable, name()).WillOnce(Return("executable_name"));
11591159
EXPECT_CALL(*executable, num_devices()).WillOnce(Return(4));
1160-
EXPECT_CALL(*executable, devices()).WillOnce(ReturnRef(device_list));
1160+
EXPECT_CALL(*executable, devices())
1161+
.WillOnce(Return(std::make_optional(device_list)));
11611162
EXPECT_CALL(*executable, addressable_devices())
11621163
.WillOnce(Return(absl::MakeSpan(addressable_devices)));
11631164
EXPECT_CALL(*executable, Fingerprint()).WillOnce(Return("fingerprint"));

xla/python/pjrt_ifrt/pjrt_executable.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,14 @@ class PjRtLoadedExecutable final
316316
absl::Span<ArrayRef> args, const ExecuteOptions& options,
317317
std::optional<DeviceListRef> devices) override;
318318

319-
const DeviceListRef& devices() const override { return devices_; }
319+
std::optional<DeviceListRef> devices() const override {
320+
if (pjrt_loaded_executable_->addressable_devices().empty()) {
321+
// Portable executable.
322+
return std::nullopt;
323+
} else {
324+
return devices_;
325+
}
326+
}
320327

321328
absl::Span<Device* const> addressable_devices() const override {
322329
DCHECK(this);

0 commit comments

Comments
 (0)