Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/elixir/lib/module/types/apply.ex
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ defmodule Module.Types.Apply do
{union(type, fun_from_non_overlapping_clauses(clauses)), fallback?, context}

{{:infer, _, clauses}, context} when length(clauses) <= @max_clauses ->
{union(type, fun_from_overlapping_clauses(clauses)), fallback?, context}
{union(type, fun_from_inferred_clauses(clauses)), fallback?, context}

{_, context} ->
{type, true, context}
Expand Down Expand Up @@ -705,7 +705,7 @@ defmodule Module.Types.Apply do
result =
case info do
{:infer, _, clauses} when length(clauses) <= @max_clauses ->
fun_from_overlapping_clauses(clauses)
fun_from_inferred_clauses(clauses)

_ ->
dynamic(fun(arity))
Expand Down
135 changes: 101 additions & 34 deletions lib/elixir/lib/module/types/descr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ defmodule Module.Types.Descr do
@not_non_empty_list Map.delete(@term, :list)
@not_list Map.replace!(@not_non_empty_list, :bitmap, @bit_top - @bit_empty_list)

@empty_intersection [0, @none, []]
@empty_difference [0, []]
@empty_intersection [0, @none, [], :fun_bottom]
@empty_difference [0, [], :fun_bottom]

# Type definitions

Expand Down Expand Up @@ -137,16 +137,17 @@ defmodule Module.Types.Descr do
@doc """
Creates a function from overlapping function clauses.
"""
def fun_from_overlapping_clauses(args_clauses) do
def fun_from_inferred_clauses(args_clauses) do
domain_clauses =
Enum.reduce(args_clauses, [], fn {args, return}, acc ->
pivot_overlapping_clause(args_to_domain(args), return, acc)
domain = args |> Enum.map(&upper_bound/1) |> args_to_domain()
pivot_overlapping_clause(domain, upper_bound(return), acc)
end)

funs =
for {domain, return} <- domain_clauses,
args <- domain_to_args(domain),
do: fun(args, return)
do: fun(args, dynamic(return))

Enum.reduce(funs, &intersection/2)
end
Expand Down Expand Up @@ -200,19 +201,19 @@ defmodule Module.Types.Descr do
def domain_to_args(descr) do
case :maps.take(:dynamic, descr) do
:error ->
tuple_elim_negations_static(descr, &Function.identity/1)
unwrap_domain_tuple(descr, fn {:closed, elems} -> elems end)

{dynamic, static} ->
tuple_elim_negations_static(static, &Function.identity/1) ++
tuple_elim_negations_static(dynamic, fn elems -> Enum.map(elems, &dynamic/1) end)
unwrap_domain_tuple(static, fn {:closed, elems} -> elems end) ++
unwrap_domain_tuple(dynamic, fn {:closed, elems} -> Enum.map(elems, &dynamic/1) end)
end
end

defp tuple_elim_negations_static(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, fn {:closed, elements} -> transform.(elements) end)
defp unwrap_domain_tuple(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, transform)
end

defp tuple_elim_negations_static(descr, _transform) when descr == %{}, do: []
defp unwrap_domain_tuple(descr, _transform) when descr == %{}, do: []

defp domain_to_flat_args(domain, arity) do
case domain_to_args(domain) do
Expand Down Expand Up @@ -1170,6 +1171,7 @@ defmodule Module.Types.Descr do

static_arrows == [] ->
# TODO: We need to validate this within the theory
arguments = Enum.map(arguments, &upper_bound/1)
{:ok, dynamic(fun_apply_static(arguments, dynamic_arrows, false))}

true ->
Expand Down Expand Up @@ -1324,9 +1326,9 @@ defmodule Module.Types.Descr do
if subtype?(rets_reached, result), do: result, else: union(result, rets_reached)
end

defp aux_apply(result, input, returns_reached, [{dom, ret} | arrow_intersections]) do
defp aux_apply(result, input, returns_reached, [{args, ret} | arrow_intersections]) do
# Calculate the part of the input not covered by this arrow's domain
dom_subtract = difference(input, args_to_domain(dom))
dom_subtract = difference(input, args_to_domain(args))

# Refine the return type by intersecting with this arrow's return type
ret_refine = intersection(returns_reached, ret)
Expand Down Expand Up @@ -1423,7 +1425,7 @@ defmodule Module.Types.Descr do
# determines emptiness.
length(neg_arguments) == positive_arity and
subtype?(args_to_domain(neg_arguments), positive_domain) and
phi_starter(neg_arguments, negation(neg_return), positives)
phi_starter(neg_arguments, neg_return, positives)
end)
end
end
Expand Down Expand Up @@ -1461,27 +1463,75 @@ defmodule Module.Types.Descr do
#
# See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2.
defp phi_starter(arguments, return, positives) do
n = length(arguments)
# Arity mismatch: if there is one positive function with a different arity,
# then it cannot be a subtype of the (arguments->type) functions.
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
false
# Optimization: When all positive functions have non-empty domains,
# we can simplify the phi function check to a direct subtyping test.
# This avoids the expensive recursive phi computation by checking only that applying the
# input to the positive intersection yields a subtype of the return
if all_non_empty_domains?([{arguments, return} | positives]) do
fun_apply_static(arguments, [positives], false)
|> subtype?(return)
else
arguments = Enum.map(arguments, &{false, &1})
phi(arguments, {false, return}, positives)
n = length(arguments)
# Arity mismatch: functions with different arities cannot be subtypes
# of the target function type (arguments -> return)
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
false
else
# Initialize memoization cache for the recursive phi computation
arguments = Enum.map(arguments, &{false, &1})
{result, _cache} = phi(arguments, {false, negation(return)}, positives, %{})
result
end
end
end

defp phi(args, {b, t}, []) do
Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t))
defp phi(args, {b, t}, [], cache) do
{Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)), cache}
end

defp phi(args, {b, ret}, [{arguments, return} | rest_positive]) do
phi(args, {true, intersection(ret, return)}, rest_positive) and
Enum.all?(Enum.with_index(arguments), fn {type, index} ->
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
|> phi({b, ret}, rest_positive)
end)
defp phi(args, {b, ret}, [{arguments, return} | rest_positive], cache) do
# Create cache key from function arguments
cache_key = {args, {b, ret}, [{arguments, return} | rest_positive]}

case Map.get(cache, cache_key) do
nil ->
# Compute result and cache it
{result1, cache} = phi(args, {true, intersection(ret, return)}, rest_positive, cache)

if not result1 do
# Store false result in cache
cache = Map.put(cache, cache_key, false)
{false, cache}
else
# This doesn't stop if one intermediate result is false?
{result2, cache} =
Enum.with_index(arguments)
|> Enum.reduce_while({true, cache}, fn {type, index}, {acc_result, acc_cache} ->
{new_result, new_cache} =
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
|> phi({b, ret}, rest_positive, acc_cache)

if new_result do
{:cont, {acc_result and new_result, new_cache}}
else
{:halt, {false, new_cache}}
end
end)

result = result1 and result2
# Store result in cache
cache = Map.put(cache, cache_key, result)
{result, cache}
end

cached_result ->
# Return cached result
{cached_result, cache}
end
end

defp all_non_empty_domains?(positives) do
Enum.all?(positives, fn {args, _ret} -> not empty?(args_to_domain(args)) end)
end

defp fun_union(bdd1, bdd2) do
Expand Down Expand Up @@ -1828,6 +1878,10 @@ defmodule Module.Types.Descr do
# b) If only the last type differs, subtracts it
# 3. Base case: adds dnf2 type to negations of dnf1 type
# The result may be larger than the initial dnf1, which is maintained in the accumulator.
defp list_difference(_, dnf) when dnf == @non_empty_list_top do
0
end

defp list_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn {t2, last2, negs2}, acc_dnf1 ->
last2 = list_tail_unfold(last2)
Expand Down Expand Up @@ -1855,6 +1909,8 @@ defmodule Module.Types.Descr do
end)
end

defp list_empty?(@non_empty_list_top), do: false

defp list_empty?(dnf) do
Enum.all?(dnf, fn {list_type, last_type, negs} ->
last_type = list_tail_unfold(last_type)
Expand Down Expand Up @@ -2115,9 +2171,6 @@ defmodule Module.Types.Descr do

defp dynamic_to_quoted(descr, opts) do
cond do
descr == %{} ->
[]

# We check for :term literally instead of using term_type?
# because we check for term_type? in to_quoted before we
# compute the difference(dynamic, static).
Expand All @@ -2127,6 +2180,9 @@ defmodule Module.Types.Descr do
single = indivisible_bitmap(descr, opts) ->
[single]

empty?(descr) ->
[]

true ->
case non_term_type_to_quoted(descr, opts) do
{:none, _meta, []} = none -> [none]
Expand Down Expand Up @@ -2395,6 +2451,10 @@ defmodule Module.Types.Descr do
if empty?(type), do: throw(:empty), else: type
end

defp map_difference(_, dnf) when dnf == @map_top do
0
end

defp map_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn
# Optimization: we are removing an open map with one field.
Expand Down Expand Up @@ -3045,10 +3105,15 @@ defmodule Module.Types.Descr do
zip_non_empty_intersection!(rest1, rest2, [non_empty_intersection!(type1, type2) | acc])
end

defp tuple_difference(_, dnf) when dnf == @tuple_top do
0
end

defp tuple_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn {tag2, elements2}, dnf1 ->
Enum.reduce(dnf1, [], fn {tag1, elements1}, acc ->
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2}) ++ acc
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2})
|> tuple_union(acc)
end)
end)
end
Expand All @@ -3063,8 +3128,10 @@ defmodule Module.Types.Descr do
if (tag == :closed and n < m) or (neg_tag == :closed and n > m) do
[{tag, elements}]
else
tuple_elim_content([], tag, elements, neg_elements) ++
tuple_union(
tuple_elim_content([], tag, elements, neg_elements),
tuple_elim_size(n, m, tag, elements, neg_tag)
)
end
end

Expand Down
8 changes: 6 additions & 2 deletions lib/elixir/lib/module/types/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ defmodule Module.Types.Expr do
add_inferred(acc, args, body)
end)

{fun_from_overlapping_clauses(acc), context}
{fun_from_inferred_clauses(acc), context}
end
end

Expand Down Expand Up @@ -461,7 +461,11 @@ defmodule Module.Types.Expr do
{args_types, context} =
Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2))

Apply.fun_apply(fun_type, args_types, call, stack, context)
if stack.mode == :traversal do
{dynamic(), context}
else
Apply.fun_apply(fun_type, args_types, call, stack, context)
end
end

def of_expr({{:., _, [callee, key_or_fun]}, meta, []} = call, expected, expr, stack, context)
Expand Down
Loading