|
8 | 8 | #include "xla/pjrt/pjrt_c_api_client.h"
|
9 | 9 | #include "xla/pjrt/pjrt_compiler.h"
|
10 | 10 | #include "xla/pjrt/tfrt_cpu_pjrt_client.h"
|
| 11 | +#include "xla/shape_util.h" |
11 | 12 |
|
12 | 13 | namespace exla {
|
13 | 14 |
|
@@ -61,10 +62,10 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
|
61 | 62 | return xla::InvalidArgument("Expected buffer to be binary.");
|
62 | 63 | }
|
63 | 64 |
|
64 |
| - xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kZeroCopy; |
| 65 | + xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy; |
65 | 66 | std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };
|
66 | 67 |
|
67 |
| - EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(device_id)); |
| 68 | + EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); |
68 | 69 | EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
|
69 | 70 | binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));
|
70 | 71 |
|
@@ -292,7 +293,7 @@ xla::StatusOr<ERL_NIF_TERM> ExlaExecutable::Run(ErlNifEnv* env,
|
292 | 293 | // executable, meaning we need to find the device corresponding to the specific device
|
293 | 294 | // id and execute on that device, we've already guaranteed this executable only has 1
|
294 | 295 | // replica
|
295 |
| - EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(device_id)); |
| 296 | + EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); |
296 | 297 | // because this is a portable executable, it only has 1 replica and so we only need
|
297 | 298 | // to get the arguments at the first position of the input buffers
|
298 | 299 | std::vector<xla::PjRtBuffer*> portable_args = input_buffers.at(0);
|
@@ -390,30 +391,49 @@ xla::StatusOr<ExlaExecutable*> ExlaClient::Compile(const mlir::OwningOpRef<mlir:
|
390 | 391 | }
|
391 | 392 |
|
392 | 393 | xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env,
|
393 |
| - ERL_NIF_TERM data, |
394 |
| - const xla::Shape& shape, |
| 394 | + std::vector<ErlNifBinary> buffer_bins, |
| 395 | + std::vector<xla::Shape> shapes, |
395 | 396 | int device_id) {
|
396 |
| - // Fast path to avoid any traversal when not sending Tuples |
397 |
| - ERL_NIF_TERM head, tail; |
398 |
| - if (!enif_get_list_cell(env, data, &head, &tail)) { |
399 |
| - return xla::InvalidArgument("infeed operation expects a list of binaries"); |
400 |
| - } |
| 397 | + std::vector<const char*> buf_ptrs; |
| 398 | + buf_ptrs.reserve(buffer_bins.size()); |
401 | 399 |
|
402 |
| - ErlNifBinary binary; |
403 |
| - if (!nif::get_binary(env, head, &binary)) { |
404 |
| - return xla::InvalidArgument("infeed operation expects a list of binaries"); |
| 400 | + for (const auto & buffer_bin : buffer_bins) { |
| 401 | + const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(buffer_bin.data)); |
| 402 | + buf_ptrs.push_back(data_ptr); |
405 | 403 | }
|
406 | 404 |
|
407 |
| - const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(binary.data)); |
408 |
| - xla::BorrowingLiteral literal(data_ptr, shape); |
409 |
| - |
410 |
| - EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id)); |
411 |
| - |
412 |
| - return device->TransferToInfeed(literal); |
| 405 | + auto shape = xla::ShapeUtil::MakeTupleShape(shapes); |
| 406 | + |
| 407 | + // Instead of pushing each buffer separately, we create a flat tuple |
| 408 | + // literal and push the whole group of buffers. |
| 409 | + // |
| 410 | + // On the CPU, XLA infeed reads buffers from a queue one at a time [1][2] |
| 411 | + // (or rather, the infeed operation is lowered to multiple queue reads), |
| 412 | + // hence pushing one at a time works fine. Pushing a flat tuple works |
| 413 | + // effectively the same, since it basically adds each element to the |
| 414 | + // queue [3]. |
| 415 | + // |
| 416 | + // On the GPU, XLA infeed reads only a single "literal" from a queue [4] |
| 417 | + // and expects it to carry all buffers for the given infeed operation. |
| 418 | + // Consequently, we need to push all buffers as a single literal. |
| 419 | + // |
| 420 | + // Given that a flat tuple works in both cases, we just do that. |
| 421 | + // |
| 422 | + // [1]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/ir_emitter.cc#L449-L450 |
| 423 | + // [2]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_runtime.cc#L222 |
| 424 | + // [3]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_xfeed.cc#L178 |
| 425 | + // [4]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/gpu/runtime/infeed_thunk.cc#L40-L41 |
| 426 | + xla::BorrowingLiteral literal(buf_ptrs, shape); |
| 427 | + |
| 428 | + EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); |
| 429 | + |
| 430 | + xla::Status status = device->TransferToInfeed(literal); |
| 431 | + |
| 432 | + return status; |
413 | 433 | }
|
414 | 434 |
|
415 | 435 | xla::StatusOr<ERL_NIF_TERM> ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape) {
|
416 |
| - EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id)); |
| 436 | + EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); |
417 | 437 |
|
418 | 438 | auto literal = std::make_shared<xla::Literal>(shape);
|
419 | 439 |
|
@@ -445,8 +465,11 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
|
445 | 465 | .memory_fraction = memory_fraction,
|
446 | 466 | .preallocate = preallocate};
|
447 | 467 |
|
| 468 | + xla::GpuClientOptions client_options = { |
| 469 | + .allocator_config = allocator_config}; |
| 470 | + |
448 | 471 | EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
|
449 |
| - xla::GetStreamExecutorGpuClient(false, allocator_config, 0)); |
| 472 | + xla::GetStreamExecutorGpuClient(client_options)); |
450 | 473 |
|
451 | 474 | return new ExlaClient(std::move(client));
|
452 | 475 | }
|
|
0 commit comments