Skip to content

Commit 55d3330

Browse files
committed
Fix warnings in Torchx and EXLA
1 parent ed59615 commit 55d3330

File tree

2 files changed

+9
-17
lines changed

2 files changed

+9
-17
lines changed

exla/lib/exla/defn.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,8 @@ defmodule EXLA.Defn do
787787
transform = Keyword.fetch!(opts, :transform_a)
788788

789789
case Value.get_typespec(b).shape do
790-
{_} = b_shape ->
791-
b_shape = Tuple.append(b_shape, 1)
790+
{dim} ->
791+
b_shape = {dim, 1}
792792

793793
b =
794794
b

torchx/lib/torchx/backend.ex

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ defmodule Torchx.Backend do
506506

507507
result =
508508
if axes == [] do
509-
aggregate_whole_tensor(t, keep_axes, &Torchx.product/1)
509+
aggregate_whole_tensor(t, &Torchx.product/1)
510510
else
511511
aggregate_over_axes(t, axes, keep_axes, &Torchx.product/3)
512512
end
@@ -523,7 +523,7 @@ defmodule Torchx.Backend do
523523

524524
result =
525525
if axes == [] do
526-
aggregate_whole_tensor(t, keep_axes, &Torchx.any/1)
526+
aggregate_whole_tensor(t, &Torchx.any/1)
527527
else
528528
aggregate_over_axes(t, axes, keep_axes, &Torchx.any/3)
529529
end
@@ -538,7 +538,7 @@ defmodule Torchx.Backend do
538538

539539
result =
540540
if axes == [] do
541-
aggregate_whole_tensor(t, keep_axes, &Torchx.all/1)
541+
aggregate_whole_tensor(t, &Torchx.all/1)
542542
else
543543
aggregate_over_axes(t, axes, keep_axes, &Torchx.all/3)
544544
end
@@ -563,18 +563,10 @@ defmodule Torchx.Backend do
563563
|> to_nx(out)
564564
end
565565

566-
defp aggregate_whole_tensor(t, keep_axes, fun) when is_function(fun, 1) do
567-
result =
568-
t
569-
|> from_nx()
570-
|> then(fun)
571-
572-
if keep_axes do
573-
shape = t.shape |> Tuple.delete_at(-1) |> Tuple.append(1)
574-
Torchx.reshape(result, shape)
575-
else
576-
result
577-
end
566+
defp aggregate_whole_tensor(t, fun) when is_function(fun, 1) do
567+
t
568+
|> from_nx()
569+
|> then(fun)
578570
end
579571

580572
defp aggregate_over_axes(t, axes, keep_axes, fun) when is_function(fun, 3) do

0 commit comments

Comments
 (0)