Skip to content

Commit 087845b

Browse files
committed
Simplify recur shared ids traversal
1 parent 9fe8f06 commit 087845b

File tree

1 file changed

+36
-38
lines changed

1 file changed

+36
-38
lines changed

exla/lib/exla/defn.ex

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,61 +1823,59 @@ defmodule EXLA.Defn do
18231823
end
18241824
end
18251825

1826-
defp collect_arg?(_id, :parameter, _args, _shared_ids),
1826+
defp recur_shared_ids(
1827+
expr,
1828+
other_ids,
1829+
%{scope_ids: ids} = state,
1830+
cache
1831+
) do
1832+
{_, cache} =
1833+
Composite.reduce(expr, {%{}, cache}, fn node, acc ->
1834+
do_recur_shared_ids(node, state, acc, {ids, other_ids})
1835+
end)
1836+
1837+
cache
1838+
end
1839+
1840+
defp shared?(_id, :parameter, _args, _shared_ids),
18271841
do: true
18281842

18291843
# We never pass reference to tuples around, only through their elements,
18301844
# so if a tuple is in a predicate, then it all must be in a predicate.
1831-
defp collect_arg?(_id, :elem, [%T{data: %Expr{id: tuple_id}}, _pos], {parent_ids, sibling_ids})
1845+
defp shared?(_id, :elem, [%T{data: %Expr{id: tuple_id}}, _pos], {parent_ids, sibling_ids})
18321846
when is_map_key(parent_ids, tuple_id) or is_map_key(sibling_ids, tuple_id),
18331847
do: true
18341848

1835-
defp collect_arg?(id, _op, _args, {parent_ids, sibling_ids}),
1849+
defp shared?(id, _op, _args, {parent_ids, sibling_ids}),
18361850
do: is_map_key(parent_ids, id) or is_map_key(sibling_ids, id)
18371851

1838-
defp collect_args(%T{data: %Expr{id: id, op: op, args: args}} = expr, {cache, ids}, shared_ids) do
1852+
defp do_recur_shared_ids(
1853+
%T{data: %Expr{id: id, op: op, args: args}} = expr,
1854+
state,
1855+
{visited, cache},
1856+
shared_ids
1857+
) do
18391858
cond do
1840-
op == :constant or collect_arg?(id, op, args, shared_ids) ->
1841-
case ids do
1842-
%{^id => {_, _, new}} ->
1843-
{new, {cache, ids}}
1844-
1845-
%{} ->
1846-
i = map_size(ids)
1847-
param = Expr.parameter(expr, i)
1848-
{param, {Map.put(cache, id, param), Map.put(ids, id, {i, expr, param})}}
1849-
end
1859+
Map.has_key?(visited, id) ->
1860+
{visited, cache}
18501861

1851-
expr = Map.get(cache, id) ->
1852-
{expr, {cache, ids}}
1862+
op == :constant or shared?(id, op, args, shared_ids) ->
1863+
{_, cache} = recur_operator(expr, state, cache)
1864+
{Map.put(visited, id, true), cache}
18531865

18541866
true ->
1855-
{args, {cache, ids}} =
1856-
Tree.apply_args(expr, :scope, {cache, ids}, &collect_args(&1, &2, shared_ids))
1867+
{_, {visited, cache}} =
1868+
Tree.apply_args(
1869+
expr,
1870+
:scope,
1871+
{visited, cache},
1872+
&{&1, do_recur_shared_ids(&1, state, &2, shared_ids)}
1873+
)
18571874

1858-
expr = put_in(expr.data.args, args)
1859-
{expr, {Map.put(cache, id, expr), ids}}
1875+
{Map.put(visited, id, true), cache}
18601876
end
18611877
end
18621878

1863-
defp recur_shared_ids(
1864-
expr,
1865-
other_ids,
1866-
%{scope_ids: ids} = state,
1867-
cache
1868-
) do
1869-
{_, ids_args} =
1870-
Composite.reduce(expr, {%{}, %{}}, fn node, acc ->
1871-
{_, acc} = collect_args(node, acc, {ids, other_ids})
1872-
acc
1873-
end)
1874-
1875-
Enum.reduce(ids_args, cache, fn {_, {_, old, _}}, cache ->
1876-
{_, cache} = recur_operator(old, state, cache)
1877-
cache
1878-
end)
1879-
end
1880-
18811879
defp to_mlir_if_branch(region, expr, current_ids, state, cache) do
18821880
comp_state = %{state | scope_ids: current_ids}
18831881

0 commit comments

Comments
 (0)