@@ -153,9 +153,6 @@ defmodule Nx.Defn.Grad do
153
153
defp reduce_args ( :take_along_axis , % { data: % { args: [ arg | _ ] } } , acc , fun ) ,
154
154
do: fun . ( arg , acc )
155
155
156
- defp reduce_args ( :take , % { data: % { args: [ arg | _ ] } } , acc , fun ) ,
157
- do: fun . ( arg , acc )
158
-
159
156
defp reduce_args ( :gather , % { data: % { args: [ arg | _ ] } } , acc , fun ) ,
160
157
do: fun . ( arg , acc )
161
158
@@ -704,69 +701,6 @@ defmodule Nx.Defn.Grad do
704
701
[ { t , g } ]
705
702
end
706
703
707
- defp grad ( :take , [ t , i , axis ] , _ans , g ) do
708
- axes_range = 0 .. ( Nx . rank ( t ) - 1 ) // 1
709
-
710
- indices_shape =
711
- axes_range
712
- |> Enum . flat_map ( fn
713
- ^ axis -> Tuple . to_list ( i . shape )
714
- _ -> [ 1 ]
715
- end )
716
- |> List . to_tuple ( )
717
-
718
- idx_tiling =
719
- t . shape
720
- |> Tuple . to_list ( )
721
- |> Enum . with_index ( fn
722
- _x , ^ axis ->
723
- List . duplicate ( 1 , Nx . rank ( i ) )
724
-
725
- x , _ ->
726
- x
727
- end )
728
- |> List . flatten ( )
729
-
730
- num_elements = Tuple . product ( g . shape )
731
-
732
- indices_for_axis =
733
- i
734
- |> Nx . reshape ( indices_shape )
735
- |> Nx . tile ( idx_tiling )
736
-
737
- axis_offset = Nx . rank ( i ) - 1
738
-
739
- indices =
740
- axes_range
741
- |> Enum . map ( fn
742
- ^ axis ->
743
- indices_for_axis
744
- |> Nx . reshape ( { num_elements , 1 } )
745
-
746
- current when current < axis ->
747
- indices_for_axis
748
- |> Nx . shape ( )
749
- |> Nx . iota ( axis: current )
750
- |> Nx . reshape ( { num_elements , 1 } )
751
-
752
- current when current > axis ->
753
- indices_for_axis
754
- |> Nx . shape ( )
755
- |> Nx . iota ( axis: current + axis_offset )
756
- |> Nx . reshape ( { num_elements , 1 } )
757
- end )
758
- |> Nx . concatenate ( axis: 1 )
759
-
760
- updates = Nx . reshape ( g , { num_elements } )
761
-
762
- g =
763
- t
764
- |> Expr . broadcast ( 0 , Nx . shape ( t ) , Nx . axes ( t ) )
765
- |> Nx . indexed_add ( indices , updates )
766
-
767
- [ { t , g } ]
768
- end
769
-
770
704
defp grad ( :gather , [ t , i , opts ] , _ans , g ) do
771
705
i_axes = opts [ :axes ]
772
706
i_shape = i . shape
0 commit comments