Skip to content

Commit 01960d9

Browse files
Remove all handling of EXLA.Shape tuple (#1475)
1 parent b75902d commit 01960d9

File tree

11 files changed

+13
-156
lines changed

11 files changed

+13
-156
lines changed

exla/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/cus
6666
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
6767

6868

69-
NVCC_RESULT := $(shell which nvcc 2> NULL)
69+
NVCC_RESULT := $(shell which nvcc 2> /dev/null)
7070
NVCC_TEST := $(notdir $(NVCC_RESULT))
7171

7272
ifeq ($(NVCC_TEST),nvcc)

exla/c_src/exla/exla.cc

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -323,22 +323,6 @@ ERL_NIF_TERM make_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
323323
return exla::nif::ok(env, exla::nif::make<xla::Shape>(env, shape));
324324
}
325325

326-
ERL_NIF_TERM make_tuple_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
327-
if (argc != 1) {
328-
return exla::nif::error(env, "Bad argument count.");
329-
}
330-
331-
std::vector<xla::Shape> shapes;
332-
333-
if (!exla::nif::get_list<xla::Shape>(env, argv[0], shapes)) {
334-
return exla::nif::error(env, "Unable to get shapes.");
335-
}
336-
337-
xla::Shape shape = xla::ShapeUtil::MakeTupleShape(shapes);
338-
339-
return exla::nif::ok(env, exla::nif::make<xla::Shape>(env, shape));
340-
}
341-
342326
ERL_NIF_TERM make_token_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
343327
if (argc != 0) {
344328
return exla::nif::error(env, "Bad argument count.");
@@ -823,7 +807,6 @@ static ErlNifFunc exla_funcs[] = {
823807
// Shape
824808
{"make_shape", 2, make_shape},
825809
{"make_token_shape", 0, make_token_shape},
826-
{"make_tuple_shape", 1, make_tuple_shape},
827810
{"get_shape_info", 1, get_shape_info},
828811
// Log Sink
829812
{"start_log_sink", 1, start_log_sink},

exla/c_src/exla/exla_client.cc

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -393,38 +393,6 @@ xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env,
393393
ERL_NIF_TERM data,
394394
const xla::Shape& shape,
395395
int device_id) {
396-
// Tuples need to be decomposed a bit
397-
if (shape.IsTuple()) {
398-
// unsupported right now
399-
if (xla::ShapeUtil::IsNestedTuple(shape)) {
400-
return xla::InvalidArgument("nested tuples are not supported in infeed operation");
401-
}
402-
403-
int num_elements = xla::ShapeUtil::TupleElementCount(shape);
404-
std::vector<const char*> buf_ptrs;
405-
buf_ptrs.reserve(num_elements);
406-
407-
ERL_NIF_TERM head, tail;
408-
while (enif_get_list_cell(env, data, &head, &tail)) {
409-
ErlNifBinary tmp_bin;
410-
if (!nif::get_binary(env, head, &tmp_bin)) {
411-
return xla::InvalidArgument("infeed operation expects a list of binaries");
412-
}
413-
414-
const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(tmp_bin.data));
415-
buf_ptrs.push_back(data_ptr);
416-
data = tail;
417-
}
418-
419-
xla::BorrowingLiteral literal(buf_ptrs, shape);
420-
421-
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id));
422-
423-
xla::Status status = device->TransferToInfeed(literal);
424-
425-
return status;
426-
}
427-
428396
// Fast path to avoid any traversal when not sending Tuples
429397
ERL_NIF_TERM head, tail;
430398
if (!enif_get_list_cell(env, data, &head, &tail)) {

exla/c_src/exla/exla_nif_util.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -479,15 +479,8 @@ int get_primitive_type(ErlNifEnv* env, ERL_NIF_TERM term, xla::PrimitiveType* ty
479479

480480
ERL_NIF_TERM make_shape_info(ErlNifEnv* env, xla::Shape shape) {
481481
if (shape.IsTuple()) {
482-
int element_count = xla::ShapeUtil::TupleElementCount(shape);
483-
std::vector<ERL_NIF_TERM> terms;
484-
terms.reserve(element_count);
485-
for (int i = 0; i < element_count; i++) {
486-
xla::Shape shape_elem = xla::ShapeUtil::GetTupleElementShape(shape, i);
487-
ERL_NIF_TERM shape_term = make<xla::Shape>(env, shape_elem);
488-
terms.push_back(shape_term);
489-
}
490-
return enif_make_list_from_array(env, &terms[0], element_count);
482+
std::cerr << "Unexpected tuple shape" << std::endl;
483+
exit(1);
491484
} else if (shape.IsArray()) {
492485
xla::PrimitiveType type = shape.element_type();
493486
absl::Span<const int64> dims = shape.dimensions();

exla/lib/exla/defn/stream.ex

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,7 @@ defmodule EXLA.Defn.Stream do
8888

8989
data_and_shapes =
9090
if client.platform == :host do
91-
# TODO: Remove first-clause once EXLA.OP is removed
92-
shapes =
93-
case send_shape do
94-
%EXLA.Shape{dtype: {:tuple, shapes}} -> shapes
95-
l when is_list(l) -> l
96-
end
97-
98-
Enum.zip(buffers, shapes)
91+
Enum.zip(buffers, send_shape)
9992
else
10093
[{buffers, send_shape}]
10194
end

exla/lib/exla/executable.ex

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ defmodule EXLA.Executable do
8787
shapes =
8888
Enum.flat_map(List.wrap(shapes), fn shape ->
8989
case shape do
90-
%Shape{dtype: {:tuple, shapes}} -> shapes
9190
shapes when is_list(shapes) -> shapes
9291
%Shape{} -> [shape]
9392
end
@@ -102,18 +101,8 @@ defmodule EXLA.Executable do
102101
end)
103102
end
104103

105-
defp strip_shape(%Shape{dtype: {:tuple, shapes}}) do
106-
subshapes = Enum.map(shapes, &strip_shape/1)
107-
%{dtype: {:tuple, subshapes}, dims: {length(subshapes)}}
108-
end
109-
110104
defp strip_shape(%Shape{dtype: dtype, dims: dims}), do: %{dtype: dtype, dims: dims}
111105

112-
defp reconstruct_shapes(%{dtype: {:tuple, shapes}}) do
113-
subshapes = Enum.map(shapes, &reconstruct_shapes/1)
114-
EXLA.Shape.make_tuple_shape(subshapes)
115-
end
116-
117106
defp reconstruct_shapes(%{dtype: dtype, dims: dims}) do
118107
EXLA.Shape.make_shape(dtype, dims)
119108
end

exla/lib/exla/nif.ex

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,6 @@ defmodule EXLA.NIF do
191191
def make_token_shape(),
192192
do: :erlang.nif_error(:undef)
193193

194-
def make_tuple_shape(_shapes),
195-
do: :erlang.nif_error(:undef)
196-
197194
def get_host_client(),
198195
do: :erlang.nif_error(:undef)
199196

exla/lib/exla/shape.ex

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,8 @@ defmodule EXLA.Shape do
1111

1212
@doc false
1313
def get_shape_info(ref) when is_reference(ref) do
14-
case EXLA.NIF.get_shape_info(ref) |> unwrap!() do
15-
{dims_term, type_str} ->
16-
%Shape{dims: dims_term, dtype: charlist_to_dtype(type_str), ref: ref}
17-
18-
children when is_list(children) ->
19-
children = Enum.map(children, &get_shape_info/1)
20-
%Shape{dims: {length(children)}, dtype: {:tuple, children}, ref: ref}
21-
end
14+
{dims_term, type_str} = EXLA.NIF.get_shape_info(ref) |> unwrap!()
15+
%Shape{dims: dims_term, dtype: charlist_to_dtype(type_str), ref: ref}
2216
end
2317

2418
@doc """
@@ -38,18 +32,6 @@ defmodule EXLA.Shape do
3832
%Shape{dims: {}, dtype: :token, ref: ref}
3933
end
4034

41-
@doc """
42-
Creates a tuple shape with the given shapes.
43-
"""
44-
def make_tuple_shape(shapes) when is_list(shapes) do
45-
refs =
46-
shapes
47-
|> Enum.map(& &1.ref)
48-
49-
ref = EXLA.NIF.make_tuple_shape(refs) |> unwrap!()
50-
%Shape{dims: {length(shapes)}, dtype: {:tuple, shapes}, ref: ref}
51-
end
52-
5335
defp validate_dims!(_dims, 0), do: :ok
5436

5537
defp validate_dims!(dims, i)
@@ -63,10 +45,6 @@ defmodule EXLA.Shape do
6345
@doc """
6446
Returns the shape size in bytes.
6547
"""
66-
def byte_size(%EXLA.Shape{dtype: {:tuple, shapes}}) do
67-
Enum.reduce(shapes, 0, &(byte_size(&1) + &2))
68-
end
69-
7048
def byte_size(%EXLA.Shape{dtype: {_, bit_size}, dims: dims}) do
7149
Tuple.product(dims) * div(bit_size, 8)
7250
end

exla/test/exla/executable_test.exs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ defmodule EXLA.ExecutableTest do
2323
t2 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))
2424

2525
assert [a = %DeviceBuffer{}] =
26-
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape]), fn b, x, y ->
26+
run_one([t1, t2], [], [t1.shape], fn b, x, y ->
2727
[Value.add(b, x, y)]
2828
end)
2929

@@ -53,7 +53,7 @@ defmodule EXLA.ExecutableTest do
5353
end)
5454

5555
assert [%DeviceBuffer{}] =
56-
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape]), fn b, x, y ->
56+
run_one([t1, t2], [], [t1.shape], fn b, x, y ->
5757
[Value.add(b, x, y)]
5858
end)
5959

@@ -88,7 +88,7 @@ defmodule EXLA.ExecutableTest do
8888
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))
8989

9090
assert [a = %DeviceBuffer{}] =
91-
run_one([t1, t2], [], {t1.shape}, fn b, x, y ->
91+
run_one([t1, t2], [], [t1.shape], fn b, x, y ->
9292
[Value.add(b, x, y)]
9393
end)
9494

@@ -100,7 +100,7 @@ defmodule EXLA.ExecutableTest do
100100
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))
101101

102102
assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] =
103-
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape, t2.shape]), fn _b, x, y ->
103+
run_one([t1, t2], [], [t1.shape, t2.shape], fn _b, x, y ->
104104
[x, y]
105105
end)
106106

@@ -117,7 +117,7 @@ defmodule EXLA.ExecutableTest do
117117
run_one(
118118
[t1, t2],
119119
[device_id: 1],
120-
EXLA.Shape.make_tuple_shape([t1.shape, t2.shape, t1.shape]),
120+
[t1.shape, t2.shape, t1.shape],
121121
fn b, x, y ->
122122
[x, y, Value.add(b, x, y)]
123123
end
@@ -158,7 +158,7 @@ defmodule EXLA.ExecutableFeedTest do
158158

159159
assert res =
160160
Task.async(fn ->
161-
run_one([], [], Shape.make_tuple_shape([Shape.make_token_shape()]), fn b ->
161+
run_one([], [], [Shape.make_token_shape()], fn b ->
162162
token = Value.create_token(b)
163163

164164
{new_token, [val]} = Value.infeed(token, t.shape)
@@ -183,7 +183,7 @@ defmodule EXLA.ExecutableFeedTest do
183183

184184
assert res =
185185
Task.async(fn ->
186-
run_one([], [], {token_shape, t.shape}, fn b ->
186+
run_one([], [], [token_shape, t.shape], fn b ->
187187
token = Value.create_token(b)
188188

189189
{token, [val]} = Value.infeed(token, t.shape)

exla/test/exla/shape_test.exs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,4 @@ defmodule EXLA.ShapeTest do
1616
assert Shape.byte_size(shape) == 2
1717
end
1818
end
19-
20-
describe "make_tuple_shape/1" do
21-
test "creates tuple shape" do
22-
s1 = Shape.make_shape({:s, 32}, {5, 5, 5})
23-
s2 = Shape.make_shape({:bf, 16}, {})
24-
s3 = Shape.make_shape({:f, 32}, {1, 1})
25-
26-
shape = Shape.make_tuple_shape([s1, s2, s3])
27-
assert %Shape{dtype: {:tuple, [_, _, _]}, dims: {3}, ref: ref} = shape
28-
assert Shape.byte_size(shape) == 506
29-
assert %Shape{dtype: {:tuple, [_, _, _]}, dims: {3}, ref: ^ref} = Shape.get_shape_info(ref)
30-
end
31-
32-
test "creates nested tuples" do
33-
s1 = Shape.make_shape({:s, 32}, {5, 5, 5})
34-
s2 = Shape.make_shape({:bf, 16}, {})
35-
s3 = Shape.make_shape({:f, 32}, {1, 1})
36-
s4 = Shape.make_shape({:s, 32}, {1})
37-
t1 = Shape.make_tuple_shape([s1, s2, s3])
38-
39-
shape = Shape.make_tuple_shape([s4, t1])
40-
41-
assert %Shape{dtype: {:tuple, [_, %Shape{dtype: {:tuple, [_, _, _]}}]}, dims: {2}, ref: ref} =
42-
shape
43-
44-
assert Shape.byte_size(shape) == 510
45-
46-
assert %Shape{
47-
dtype: {:tuple, [_, %Shape{dtype: {:tuple, [_, _, _]}}]},
48-
dims: {2},
49-
ref: ^ref
50-
} = Shape.get_shape_info(ref)
51-
end
52-
end
5319
end

0 commit comments

Comments
 (0)