@@ -506,7 +506,7 @@ defmodule Torchx.Backend do
506
506
507
507
result =
508
508
if axes == [ ] do
509
- aggregate_whole_tensor ( t , keep_axes , & Torchx . product / 1 )
509
+ aggregate_whole_tensor ( t , & Torchx . product / 1 )
510
510
else
511
511
aggregate_over_axes ( t , axes , keep_axes , & Torchx . product / 3 )
512
512
end
@@ -523,7 +523,7 @@ defmodule Torchx.Backend do
523
523
524
524
result =
525
525
if axes == [ ] do
526
- aggregate_whole_tensor ( t , keep_axes , & Torchx . any / 1 )
526
+ aggregate_whole_tensor ( t , & Torchx . any / 1 )
527
527
else
528
528
aggregate_over_axes ( t , axes , keep_axes , & Torchx . any / 3 )
529
529
end
@@ -538,7 +538,7 @@ defmodule Torchx.Backend do
538
538
539
539
result =
540
540
if axes == [ ] do
541
- aggregate_whole_tensor ( t , keep_axes , & Torchx . all / 1 )
541
+ aggregate_whole_tensor ( t , & Torchx . all / 1 )
542
542
else
543
543
aggregate_over_axes ( t , axes , keep_axes , & Torchx . all / 3 )
544
544
end
@@ -563,18 +563,10 @@ defmodule Torchx.Backend do
563
563
|> to_nx ( out )
564
564
end
565
565
566
- defp aggregate_whole_tensor ( t , keep_axes , fun ) when is_function ( fun , 1 ) do
567
- result =
568
- t
569
- |> from_nx ( )
570
- |> then ( fun )
571
-
572
- if keep_axes do
573
- shape = t . shape |> Tuple . delete_at ( - 1 ) |> Tuple . append ( 1 )
574
- Torchx . reshape ( result , shape )
575
- else
576
- result
577
- end
566
+ defp aggregate_whole_tensor ( t , fun ) when is_function ( fun , 1 ) do
567
+ t
568
+ |> from_nx ( )
569
+ |> then ( fun )
578
570
end
579
571
580
572
defp aggregate_over_axes ( t , axes , keep_axes , fun ) when is_function ( fun , 3 ) do
0 commit comments