Skip to content

Commit ac771f5

Browse files
committed
Pattern match on %Value{}
1 parent 28d2126 commit ac771f5

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

exla/lib/exla/defn.ex

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ defmodule EXLA.Defn do
12101210
Value.dynamic_update_slice(tensor, slice, start_indices)
12111211
end
12121212

1213-
defp to_operator(:take, [%mod{} = tensor, indices, axis], _ans, _state) do
1213+
defp to_operator(:take, [%Value{} = tensor, indices, axis], _ans, _state) do
12141214
tensor_rank = tensor |> op_shape() |> tuple_size()
12151215
indices_rank = indices |> op_shape() |> tuple_size()
12161216
result_rank = tensor_rank - 1 + indices_rank
@@ -1221,7 +1221,7 @@ defmodule EXLA.Defn do
12211221
collapsed_slice_dims = [axis]
12221222
start_index_map = [axis]
12231223

1224-
mod.gather(
1224+
Value.gather(
12251225
tensor,
12261226
indices,
12271227
index_vector_dim,
@@ -1232,7 +1232,7 @@ defmodule EXLA.Defn do
12321232
)
12331233
end
12341234

1235-
defp to_operator(:take_along_axis, [%mod{} = tensor, indices, axis], _ans, state) do
1235+
defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], _ans, state) do
12361236
indices_shape = op_shape(indices)
12371237
indices_rank = tuple_size(indices_shape)
12381238

@@ -1244,22 +1244,22 @@ defmodule EXLA.Defn do
12441244
collapsed_slice_dims = Enum.to_list(axes_range)
12451245
start_index_map = Enum.to_list(axes_range)
12461246

1247-
indices_exla_shape = mod.get_shape(indices)
1247+
indices_exla_shape = Value.get_shape(indices)
12481248

12491249
iotas =
12501250
Enum.map(axes_range, fn axis ->
1251-
mod.iota(state.builder, indices_exla_shape, axis)
1251+
Value.iota(state.builder, indices_exla_shape, axis)
12521252
end)
12531253

12541254
new_axis_shape = Tuple.append(indices_shape, 1)
12551255

12561256
indices =
12571257
iotas
12581258
|> List.replace_at(axis, indices)
1259-
|> Enum.map(&mod.reshape(&1, new_axis_shape))
1260-
|> mod.concatenate(indices_rank)
1259+
|> Enum.map(&Value.reshape(&1, new_axis_shape))
1260+
|> Value.concatenate(indices_rank)
12611261

1262-
mod.gather(
1262+
Value.gather(
12631263
tensor,
12641264
indices,
12651265
index_vector_dim,
@@ -1270,7 +1270,7 @@ defmodule EXLA.Defn do
12701270
)
12711271
end
12721272

1273-
defp to_operator(:gather, [%mod{} = tensor, indices, opts], _ans, _state) do
1273+
defp to_operator(:gather, [%Value{} = tensor, indices, opts], _ans, _state) do
12741274
axes = Keyword.fetch!(opts, :axes)
12751275
tensor_shape = op_shape(tensor)
12761276
tensor_rank = tuple_size(tensor_shape)
@@ -1284,7 +1284,7 @@ defmodule EXLA.Defn do
12841284

12851285
batch_size = tensor_rank - length(axes)
12861286
offset_dims = count_up(batch_size, batch_size)
1287-
mod.gather(tensor, indices, index_vector_dim, slice_sizes, offset_dims, axes, axes)
1287+
Value.gather(tensor, indices, index_vector_dim, slice_sizes, offset_dims, axes, axes)
12881288
end
12891289

12901290
defp to_operator(:reverse, [%Value{} = tensor, axes], _ans, _state) do
@@ -1339,7 +1339,7 @@ defmodule EXLA.Defn do
13391339
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
13401340
end
13411341

1342-
defp fft(exla_op, [%mod{} = tensor, opts], %{type: type}, state) do
1342+
defp fft(exla_op, [%Value{} = tensor, opts], %{type: type}, state) do
13431343
n = opts[:length]
13441344
axis = opts[:axis]
13451345
output_type = Nx.Type.to_complex(type)
@@ -1362,15 +1362,15 @@ defmodule EXLA.Defn do
13621362
|> List.to_tuple()
13631363

13641364
tensor
1365-
|> mod.transpose(permutation)
1365+
|> Value.transpose(permutation)
13661366
|> exla_op.([n])
1367-
|> mod.transpose(permutation)
1367+
|> Value.transpose(permutation)
13681368
else
13691369
exla_op.(tensor, [n])
13701370
end
13711371
end
13721372

1373-
defp fft2(exla_op, [%mod{} = tensor, opts], %{type: type}, state) do
1373+
defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type}, state) do
13741374
[l1, l2] = lengths = opts[:lengths]
13751375
[ax1, ax2] = axes = opts[:axes]
13761376
output_type = Nx.Type.to_complex(type)
@@ -1399,9 +1399,9 @@ defmodule EXLA.Defn do
13991399
|> List.to_tuple()
14001400

14011401
tensor
1402-
|> mod.transpose(permutation)
1402+
|> Value.transpose(permutation)
14031403
|> exla_op.(lengths)
1404-
|> mod.transpose(permutation)
1404+
|> Value.transpose(permutation)
14051405
else
14061406
exla_op.(tensor, lengths)
14071407
end

0 commit comments

Comments
 (0)