diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 39b2341115..e2610cc4e9 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -2147,24 +2147,34 @@ REACTANT_ABI int ifrt_client_addressable_device_count(ifrt::Client *client) { } REACTANT_ABI void ifrt_client_devices(ifrt::Client *client, - ifrt::Device **out_devices) { + ifrt::Device **out_devices, + uint64_t num_devices) { auto span = client->devices(); + if (span.size() != num_devices) + ReactantThrowError("Incorrect number of devices provided"); + assert(span.size() == num_devices); for (int i = 0; i < span.size(); i++) { out_devices[i] = span[i]; } } REACTANT_ABI void ifrt_client_addressable_devices(ifrt::Client *client, - ifrt::Device **out_devices) { + ifrt::Device **out_devices, + uint64_t num_devices) { auto span = client->addressable_devices(); + if (span.size() != num_devices) + ReactantThrowError("Incorrect number of devices provided"); for (int i = 0; i < span.size(); i++) { out_devices[i] = span[i]; } } REACTANT_ABI void ifrt_client_all_devices(ifrt::Client *client, - ifrt::Device **out_devices) { + ifrt::Device **out_devices, + uint64_t num_devices) { auto span = client->GetAllDevices(); + if (span.size() != num_devices) + ReactantThrowError("Incorrect number of devices provided"); for (int i = 0; i < span.size(); i++) { out_devices[i] = span[i]; }