@@ -598,8 +598,49 @@ defmodule EXLA.Defn do
598
598
599
599
defp cached_recur_operator (
600
600
:optional ,
601
- % T { data: % Expr { args: [ % { data: % { op: :top_k , args: [ tensor , opts ] } } , expr , _callback ] } } =
602
- _out ,
601
+ % T {
602
+ data: % Expr {
603
+ args: [ % { data: % { op: :take , args: [ tensor , indices , opts ] } } , expr , _callback ]
604
+ }
605
+ } ,
606
+ state ,
607
+ cache
608
+ ) do
609
+ axis = opts [ :axis ]
610
+ { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
611
+ { indices , cache } = recur_operator ( indices , state , cache ) |> unwrap_single_tensor! ( )
612
+
613
+ tensor_rank = tensor |> op_shape ( ) |> tuple_size ( )
614
+ indices_rank = indices |> op_shape ( ) |> tuple_size ( )
615
+ result_rank = tensor_rank - 1 + indices_rank
616
+
617
+ index_vector_dim = indices_rank
618
+ slice_sizes = tensor |> op_shape ( ) |> put_elem ( axis , 1 ) |> Tuple . to_list ( )
619
+
620
+ { left , right } = result_rank |> axes_for_rank ( ) |> Enum . split ( axis )
621
+ offset_dims = left ++ Enum . drop ( right , indices_rank )
622
+
623
+ collapsed_slice_dims = [ axis ]
624
+ start_index_map = [ axis ]
625
+
626
+ result =
627
+ Value . gather (
628
+ tensor ,
629
+ indices ,
630
+ index_vector_dim ,
631
+ slice_sizes ,
632
+ offset_dims ,
633
+ collapsed_slice_dims ,
634
+ start_index_map ,
635
+ expr_to_typespec ( expr )
636
+ )
637
+
638
+ { result , cache }
639
+ end
640
+
641
+ defp cached_recur_operator (
642
+ :optional ,
643
+ % T { data: % Expr { args: [ % { data: % { op: :top_k , args: [ tensor , opts ] } } , expr , _callback ] } } ,
603
644
state ,
604
645
cache
605
646
) do
@@ -612,26 +653,24 @@ defmodule EXLA.Defn do
612
653
613
654
defp cached_recur_operator (
614
655
:optional ,
615
- % T { data: % Expr { args: [ % { data: % { op: :fft2 , args: [ tensor , opts ] } } , _expr , _callback ] } } =
616
- out ,
656
+ % T { data: % Expr { args: [ % { data: % { op: :fft2 , args: [ tensor , opts ] } } , expr , _callback ] } } ,
617
657
state ,
618
658
cache
619
659
) do
620
660
{ tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
621
661
622
- { fft2 ( & Value . fft ( & 1 , :fft , & 2 , & 3 ) , [ tensor , opts ] , out , state ) , cache }
662
+ { fft2 ( & Value . fft ( & 1 , :fft , & 2 , & 3 ) , [ tensor , opts ] , expr , state ) , cache }
623
663
end
624
664
625
665
defp cached_recur_operator (
626
666
:optional ,
627
- % T { data: % Expr { args: [ % { data: % { op: :ifft2 , args: [ tensor , opts ] } } , _expr , _callback ] } } =
628
- out ,
667
+ % T { data: % Expr { args: [ % { data: % { op: :ifft2 , args: [ tensor , opts ] } } , expr , _callback ] } } ,
629
668
state ,
630
669
cache
631
670
) do
632
671
{ tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
633
672
634
- { fft2 ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , [ tensor , opts ] , out , state ) , cache }
673
+ { fft2 ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , [ tensor , opts ] , expr , state ) , cache }
635
674
end
636
675
637
676
defp cached_recur_operator ( :optional , % T { data: % Expr { args: args } } , state , cache ) do
0 commit comments