@@ -1823,61 +1823,59 @@ defmodule EXLA.Defn do
1823
1823
end
1824
1824
end
1825
1825
1826
- defp collect_arg? ( _id , :parameter , _args , _shared_ids ) ,
1826
+ defp recur_shared_ids (
1827
+ expr ,
1828
+ other_ids ,
1829
+ % { scope_ids: ids } = state ,
1830
+ cache
1831
+ ) do
1832
+ { _ , cache } =
1833
+ Composite . reduce ( expr , { % { } , cache } , fn node , acc ->
1834
+ do_recur_shared_ids ( node , state , acc , { ids , other_ids } )
1835
+ end )
1836
+
1837
+ cache
1838
+ end
1839
+
1840
+ defp shared? ( _id , :parameter , _args , _shared_ids ) ,
1827
1841
do: true
1828
1842
1829
1843
# We never pass reference to tuples around, only through their elements,
1830
1844
# so if a tuple is in a predicate, then it all must be in a predicate.
1831
- defp collect_arg ?( _id , :elem , [ % T { data: % Expr { id: tuple_id } } , _pos ] , { parent_ids , sibling_ids } )
1845
+ defp shared ?( _id , :elem , [ % T { data: % Expr { id: tuple_id } } , _pos ] , { parent_ids , sibling_ids } )
1832
1846
when is_map_key ( parent_ids , tuple_id ) or is_map_key ( sibling_ids , tuple_id ) ,
1833
1847
do: true
1834
1848
1835
- defp collect_arg ?( id , _op , _args , { parent_ids , sibling_ids } ) ,
1849
+ defp shared ?( id , _op , _args , { parent_ids , sibling_ids } ) ,
1836
1850
do: is_map_key ( parent_ids , id ) or is_map_key ( sibling_ids , id )
1837
1851
1838
- defp collect_args ( % T { data: % Expr { id: id , op: op , args: args } } = expr , { cache , ids } , shared_ids ) do
1852
+ defp do_recur_shared_ids (
1853
+ % T { data: % Expr { id: id , op: op , args: args } } = expr ,
1854
+ state ,
1855
+ { visited , cache } ,
1856
+ shared_ids
1857
+ ) do
1839
1858
cond do
1840
- op == :constant or collect_arg? ( id , op , args , shared_ids ) ->
1841
- case ids do
1842
- % { ^ id => { _ , _ , new } } ->
1843
- { new , { cache , ids } }
1844
-
1845
- % { } ->
1846
- i = map_size ( ids )
1847
- param = Expr . parameter ( expr , i )
1848
- { param , { Map . put ( cache , id , param ) , Map . put ( ids , id , { i , expr , param } ) } }
1849
- end
1859
+ Map . has_key? ( visited , id ) ->
1860
+ { visited , cache }
1850
1861
1851
- expr = Map . get ( cache , id ) ->
1852
- { expr , { cache , ids } }
1862
+ op == :constant or shared? ( id , op , args , shared_ids ) ->
1863
+ { _ , cache } = recur_operator ( expr , state , cache )
1864
+ { Map . put ( visited , id , true ) , cache }
1853
1865
1854
1866
true ->
1855
- { args , { cache , ids } } =
1856
- Tree . apply_args ( expr , :scope , { cache , ids } , & collect_args ( & 1 , & 2 , shared_ids ) )
1867
+ { _ , { visited , cache } } =
1868
+ Tree . apply_args (
1869
+ expr ,
1870
+ :scope ,
1871
+ { visited , cache } ,
1872
+ & { & 1 , do_recur_shared_ids ( & 1 , state , & 2 , shared_ids ) }
1873
+ )
1857
1874
1858
- expr = put_in ( expr . data . args , args )
1859
- { expr , { Map . put ( cache , id , expr ) , ids } }
1875
+ { Map . put ( visited , id , true ) , cache }
1860
1876
end
1861
1877
end
1862
1878
1863
- defp recur_shared_ids (
1864
- expr ,
1865
- other_ids ,
1866
- % { scope_ids: ids } = state ,
1867
- cache
1868
- ) do
1869
- { _ , ids_args } =
1870
- Composite . reduce ( expr , { % { } , % { } } , fn node , acc ->
1871
- { _ , acc } = collect_args ( node , acc , { ids , other_ids } )
1872
- acc
1873
- end )
1874
-
1875
- Enum . reduce ( ids_args , cache , fn { _ , { _ , old , _ } } , cache ->
1876
- { _ , cache } = recur_operator ( old , state , cache )
1877
- cache
1878
- end )
1879
- end
1880
-
1881
1879
defp to_mlir_if_branch ( region , expr , current_ids , state , cache ) do
1882
1880
comp_state = % { state | scope_ids: current_ids }
1883
1881
0 commit comments