Skip to content

Commit e26112d

Browse files
Update XLA and fix infeed on the GPU (#1487)
Co-authored-by: José Valim <[email protected]>
1 parent 607ae46 commit e26112d

File tree

8 files changed

+82
-58
lines changed

8 files changed

+82
-58
lines changed

exla/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ endif
4141
LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared
4242

4343
ifeq ($(shell uname -s), Darwin)
44-
LDFLAGS += -flat_namespace -undefined suppress -rpath @loader_path/xla_extension/lib
44+
LDFLAGS += -flat_namespace -undefined dynamic_lookup -rpath @loader_path/xla_extension/lib
4545
else
4646
# Use a relative RPATH, so at runtime libexla.so looks for libxla_extension.so
4747
# in ./lib regardless of the absolute location. This way priv can be safely

exla/c_src/exla/exla.cc

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
481481
ptr = result.first;
482482
}
483483

484-
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env);
484+
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);
485485

486486
std::function<void()> on_delete_callback = []() {};
487487
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);
@@ -573,29 +573,40 @@ ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg
573573
return exla::nif::error(env, "Unable to get device ID.");
574574
}
575575

576+
std::vector<ErlNifBinary> buffer_bins;
577+
std::vector<xla::Shape> shapes;
578+
576579
ERL_NIF_TERM head, tail;
577580
while (enif_get_list_cell(env, data, &head, &tail)) {
578581
const ERL_NIF_TERM* terms;
579582
int count;
580-
xla::Shape shape;
581583

582584
if (!enif_get_tuple(env, head, &count, &terms) && count != 2) {
583-
return exla::nif::error(env, "Unable to binary-shape tuple.");
585+
return exla::nif::error(env, "Unable to {binary, shape} tuple.");
584586
}
585587

588+
ErlNifBinary buffer_bin;
589+
if (!exla::nif::get_binary(env, terms[0], &buffer_bin)) {
590+
return exla::nif::error(env, "Unable to binary.");
591+
}
592+
593+
xla::Shape shape;
586594
if (!exla::nif::get_typespec_as_xla_shape(env, terms[1], &shape)) {
587595
return exla::nif::error(env, "Unable to get shape.");
588596
}
589597

590-
xla::Status transfer_status = (*client)->TransferToInfeed(env, terms[0], shape, device_id);
591-
592-
if (!transfer_status.ok()) {
593-
return exla::nif::error(env, transfer_status.message().data());
594-
}
598+
buffer_bins.push_back(buffer_bin);
599+
shapes.push_back(shape);
595600

596601
data = tail;
597602
}
598603

604+
xla::Status transfer_status = (*client)->TransferToInfeed(env, buffer_bins, shapes, device_id);
605+
606+
if (!transfer_status.ok()) {
607+
return exla::nif::error(env, transfer_status.message().data());
608+
}
609+
599610
return exla::nif::ok(env);
600611
}
601612

@@ -668,7 +679,7 @@ ERL_NIF_TERM copy_buffer_to_device(ErlNifEnv* env, int argc, const ERL_NIF_TERM
668679
}
669680

670681
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device,
671-
(*client)->client()->LookupDevice(device_id), env);
682+
(*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);
672683
EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buf,
673684
(*buffer)->CopyToDevice(device), env);
674685

exla/c_src/exla/exla_client.cc

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "xla/pjrt/pjrt_c_api_client.h"
99
#include "xla/pjrt/pjrt_compiler.h"
1010
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
11+
#include "xla/shape_util.h"
1112

1213
namespace exla {
1314

@@ -61,10 +62,10 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
6162
return xla::InvalidArgument("Expected buffer to be binary.");
6263
}
6364

64-
xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kZeroCopy;
65+
xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy;
6566
std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };
6667

67-
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(device_id));
68+
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
6869
EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
6970
binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));
7071

@@ -292,7 +293,7 @@ xla::StatusOr<ERL_NIF_TERM> ExlaExecutable::Run(ErlNifEnv* env,
292293
// executable, meaning we need to find the device corresponding to the specific device
293294
// id and execute on that device, we've already guaranteed this executable only has 1
294295
// replica
295-
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(device_id));
296+
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
296297
// because this is a portable executable, it only has 1 replica and so we only need
297298
// to get the arguments at the first position of the input buffers
298299
std::vector<xla::PjRtBuffer*> portable_args = input_buffers.at(0);
@@ -390,30 +391,49 @@ xla::StatusOr<ExlaExecutable*> ExlaClient::Compile(const mlir::OwningOpRef<mlir:
390391
}
391392

392393
xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env,
393-
ERL_NIF_TERM data,
394-
const xla::Shape& shape,
394+
std::vector<ErlNifBinary> buffer_bins,
395+
std::vector<xla::Shape> shapes,
395396
int device_id) {
396-
// Fast path to avoid any traversal when not sending Tuples
397-
ERL_NIF_TERM head, tail;
398-
if (!enif_get_list_cell(env, data, &head, &tail)) {
399-
return xla::InvalidArgument("infeed operation expects a list of binaries");
400-
}
397+
std::vector<const char*> buf_ptrs;
398+
buf_ptrs.reserve(buffer_bins.size());
401399

402-
ErlNifBinary binary;
403-
if (!nif::get_binary(env, head, &binary)) {
404-
return xla::InvalidArgument("infeed operation expects a list of binaries");
400+
for (const auto & buffer_bin : buffer_bins) {
401+
const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(buffer_bin.data));
402+
buf_ptrs.push_back(data_ptr);
405403
}
406404

407-
const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(binary.data));
408-
xla::BorrowingLiteral literal(data_ptr, shape);
409-
410-
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id));
411-
412-
return device->TransferToInfeed(literal);
405+
auto shape = xla::ShapeUtil::MakeTupleShape(shapes);
406+
407+
// Instead of pushing each buffer separately, we create a flat tuple
408+
// literal and push the whole group of buffers.
409+
//
410+
// On the CPU, XLA infeed reads buffers from a queue one at a time [1][2]
411+
// (or rather, the infeed operation is lowered to multiple queue reads),
412+
// hence pushing one at a time works fine. Pushing a flat tuple works
413+
// effectively the same, since it basically adds each element to the
414+
// queue [3].
415+
//
416+
// On the GPU, XLA infeed reads only a single "literal" from a queue [4]
417+
// and expects it to carry all buffers for the given infeed operation.
418+
// Consequently, we need to push all buffers as a single literal.
419+
//
420+
// Given that a flat tuple works in both cases, we just do that.
421+
//
422+
// [1]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/ir_emitter.cc#L449-L450
423+
// [2]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_runtime.cc#L222
424+
// [3]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_xfeed.cc#L178
425+
// [4]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/gpu/runtime/infeed_thunk.cc#L40-L41
426+
xla::BorrowingLiteral literal(buf_ptrs, shape);
427+
428+
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
429+
430+
xla::Status status = device->TransferToInfeed(literal);
431+
432+
return status;
413433
}
414434

415435
xla::StatusOr<ERL_NIF_TERM> ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape) {
416-
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id));
436+
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
417437

418438
auto literal = std::make_shared<xla::Literal>(shape);
419439

@@ -445,8 +465,11 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
445465
.memory_fraction = memory_fraction,
446466
.preallocate = preallocate};
447467

468+
xla::GpuClientOptions client_options = {
469+
.allocator_config = allocator_config};
470+
448471
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
449-
xla::GetStreamExecutorGpuClient(false, allocator_config, 0));
472+
xla::GetStreamExecutorGpuClient(client_options));
450473

451474
return new ExlaClient(std::move(client));
452475
}

exla/c_src/exla/exla_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ class ExlaClient {
9292

9393
// TODO(seanmor5): This is device logic and should be refactored
9494
xla::Status TransferToInfeed(ErlNifEnv* env,
95-
ERL_NIF_TERM data,
96-
const xla::Shape& shape,
95+
std::vector<ErlNifBinary> buffer_bins,
96+
std::vector<xla::Shape> shapes,
9797
int device_id);
9898

9999
xla::StatusOr<ERL_NIF_TERM> TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape);

exla/lib/exla/client.ex

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,15 @@ defmodule EXLA.Client do
8888
@doc """
8989
Sends `data_and_typespecs` to device infeed.
9090
91-
`data_and_typespecs` must be a list of two element tuples where the
92-
first element is a binary or a flat list of binaries and the second
93-
element is a `EXLA.Typespec`.
91+
`data_and_typespecs` is a list of values corresponding to a single
92+
infeed operation. It must be a list of two-element tuples where the
93+
first element is a binary and the second element is an `EXLA.Typespec`.
9494
"""
9595
def to_infeed(%EXLA.Client{ref: client}, device_id, data_and_typespecs)
9696
when is_list(data_and_typespecs) do
9797
data_and_typespecs =
98-
Enum.map(data_and_typespecs, fn
99-
{binary, typespec} when is_binary(binary) ->
100-
{[binary], EXLA.Typespec.nif_encode(typespec)}
101-
102-
{[binary | _] = data, typespec} when is_binary(binary) ->
103-
{data, EXLA.Typespec.nif_encode(typespec)}
98+
Enum.map(data_and_typespecs, fn {binary, typespec} when is_binary(binary) ->
99+
{binary, EXLA.Typespec.nif_encode(typespec)}
104100
end)
105101

106102
EXLA.NIF.transfer_to_infeed(client, device_id, data_and_typespecs) |> unwrap!()

exla/lib/exla/defn/stream.ex

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do
22
@moduledoc false
33

44
keys =
5-
[:lock, :outfeed, :pid, :runner, :send, :send_typespec, :send_indexes] ++
5+
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs, :send_indexes] ++
66
[:recv, :recv_length, :done, :client, :device_id]
77

88
@derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]}
@@ -15,7 +15,7 @@ defmodule EXLA.Defn.Stream do
1515
runner,
1616
outfeed,
1717
send,
18-
send_typespec,
18+
send_typespecs,
1919
send_indexes,
2020
recv,
2121
recv_typespecs,
@@ -39,7 +39,7 @@ defmodule EXLA.Defn.Stream do
3939
outfeed: outfeed,
4040
lock: lock,
4141
send: send,
42-
send_typespec: send_typespec,
42+
send_typespecs: send_typespecs,
4343
send_indexes: send_indexes,
4444
recv: recv,
4545
recv_length: length(recv_typespecs),
@@ -64,7 +64,7 @@ defmodule EXLA.Defn.Stream do
6464
client: client,
6565
device_id: device_id,
6666
send: send,
67-
send_typespec: send_typespec,
67+
send_typespecs: send_typespecs,
6868
send_indexes: send_indexes
6969
} = stream
7070

@@ -86,17 +86,11 @@ defmodule EXLA.Defn.Stream do
8686
"""
8787
end
8888

89-
data_and_typespecs =
90-
if client.platform == :host do
91-
Enum.zip(buffers, send_typespec)
92-
else
93-
[{buffers, send_typespec}]
94-
end
95-
9689
pred = EXLA.Typespec.tensor({:pred, 8}, {})
90+
data_and_typespecs = Enum.zip(buffers, send_typespecs)
9791

98-
:ok =
99-
EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred} | data_and_typespecs])
92+
:ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}])
93+
:ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs)
10094
end
10195

10296
defp nx_to_io(container, indexes) do

exla/mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ defmodule EXLA.MixProject do
6666
# {:nx, "~> 0.7.1"},
6767
{:nx, path: "../nx"},
6868
{:telemetry, "~> 0.4.0 or ~> 1.0"},
69-
{:xla, "~> 0.6.0", runtime: false},
69+
{:xla, "~> 0.7.0", runtime: false},
7070
{:elixir_make, "~> 0.6", runtime: false},
7171
{:benchee, "~> 1.0", only: :dev},
7272
{:ex_doc, "~> 0.29", only: :docs},

exla/mix.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
44
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
55
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
6-
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
6+
"elixir_make": {:hex, :elixir_make, "0.8.3", "d38d7ee1578d722d89b4d452a3e36bcfdc644c618f0d063b874661876e708683", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "5c99a18571a756d4af7a4d89ca75c28ac899e6103af6f223982f09ce44942cc9"},
77
"ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"},
88
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
99
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
@@ -13,5 +13,5 @@
1313
"nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"},
1414
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
1515
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
16-
"xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"},
16+
"xla": {:hex, :xla, "0.7.0", "413880fb8f665d93636908092a409e549545e190b38b91107832e78379190d93", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "8eb5c5510e6737fd9e4860bfb0d8cafb13ab94b1b4123edd347562a71e19ec27"},
1717
}

0 commit comments

Comments
 (0)