Skip to content

Commit 22006fd

Browse files
committed
Track variables across equality in guards
1 parent e87b935 commit 22006fd

File tree

7 files changed

+263
-61
lines changed

7 files changed

+263
-61
lines changed

lib/elixir/lib/module/types.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ defmodule Module.Types do
332332
{return_type, context} =
333333
Expr.of_expr(body, Descr.term(), body, stack, context)
334334

335-
args_types = Pattern.of_domain(trees, context)
335+
args_types = Pattern.of_domain(trees, stack, context)
336336

337337
{type_index, inferred} =
338338
add_inferred(inferred, args_types, return_type, total - 1, [])

lib/elixir/lib/module/types/apply.ex

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ defmodule Module.Types.Apply do
396396
common = intersection(left_type, right_type)
397397

398398
cond do
399+
# This check is incomplete. After all, we could have the number type nested
400+
# inside a tuple or a list and the comparison would still be valid.
401+
# However, nested comparison between distinct numbers is very uncommon,
402+
# so we only check the direct value here.
399403
empty?(common) and not (number_type?(left_type) and number_type?(right_type)) ->
400404
error = {:mismatched_comparison, left_type, right_type}
401405
remote_error(error, :erlang, name, 2, expr, stack, context)
@@ -479,7 +483,6 @@ defmodule Module.Types.Apply do
479483
remote_domain(mod, fun, args, expected, elem(expr, 1), stack, context)
480484
end
481485

482-
@number union(integer(), float())
483486
@empty_list empty_list()
484487
@non_empty_list non_empty_list(term())
485488
@empty_map empty_map()
@@ -546,6 +549,7 @@ defmodule Module.Types.Apply do
546549
:maybe_false -> {name in [:"/=", :"=/="], @atom_false}
547550
end
548551

552+
# This logic mirrors the code in `Pattern.of_pattern_tree`
549553
# If it is a singleton, we can always be precise
550554
if singleton?(literal_type) do
551555
expected = if polarity, do: literal_type, else: negation(literal_type)
@@ -562,7 +566,7 @@ defmodule Module.Types.Apply do
562566
# We are checking for `not x == 1` or similar, we can't say anything about x
563567
polarity == false -> term()
564568
# We are checking for `x == 1`, make sure x is integer or float
565-
number_type?(literal_type) and name in [:==, :"/="] -> union(literal_type, @number)
569+
name in [:==, :"/="] -> numberize(literal_type)
566570
# Otherwise we have the literal type as is
567571
true -> literal_type
568572
end
@@ -579,7 +583,10 @@ defmodule Module.Types.Apply do
579583
return_compare(name, left_type, right_type, boolean(), both_literal?, expr, stack, context)
580584
end
581585

582-
defp return_compare(name, left_type, right_type, result, skip_check?, expr, stack, context) do
586+
@doc """
587+
Computes the return type of the comparison application.
588+
"""
589+
def return_compare(name, left_type, right_type, result, skip_check?, expr, stack, context) do
583590
result = return(result, [left_type, right_type], stack)
584591

585592
cond do
@@ -589,6 +596,10 @@ defmodule Module.Types.Apply do
589596
name in [:==, :"/="] and number_type?(left_type) and number_type?(right_type) ->
590597
{result, context}
591598

599+
# This check is incomplete. After all, we could have the number type nested
600+
# inside a tuple or a list and the comparison would still be valid.
601+
# However, nested comparison between distinct numbers is very uncommon,
602+
# so we only check the direct value here.
592603
disjoint?(left_type, right_type) ->
593604
error = {:mismatched_comparison, left_type, right_type}
594605
remote_error(error, :erlang, name, 2, expr, stack, context)

lib/elixir/lib/module/types/descr.ex

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,10 @@ defmodule Module.Types.Descr do
381381
end
382382
end
383383

384+
@compile {:inline, pop_dynamic: 1}
385+
defp pop_dynamic(:term), do: {:term, :term}
386+
defp pop_dynamic(descr), do: Map.pop(descr, :dynamic, descr)
387+
384388
@compile {:inline, maybe_union: 2}
385389
defp maybe_union(nil, _fun), do: nil
386390
defp maybe_union(descr, fun), do: union(descr, fun.())
@@ -394,18 +398,16 @@ defmodule Module.Types.Descr do
394398
def union(other, none) when none == @none, do: other
395399

396400
def union(left, right) do
397-
left = unfold(left)
398-
right = unfold(right)
399401
is_gradual_left = gradual?(left)
400402
is_gradual_right = gradual?(right)
401403

402404
cond do
403405
is_gradual_left and not is_gradual_right ->
404-
right_with_dynamic = Map.put(right, :dynamic, right)
406+
right_with_dynamic = Map.put(unfold(right), :dynamic, right)
405407
union_static(left, right_with_dynamic)
406408

407409
is_gradual_right and not is_gradual_left ->
408-
left_with_dynamic = Map.put(left, :dynamic, left)
410+
left_with_dynamic = Map.put(unfold(left), :dynamic, left)
409411
union_static(left_with_dynamic, right)
410412

411413
true ->
@@ -436,18 +438,16 @@ defmodule Module.Types.Descr do
436438
def intersection(other, %{dynamic: :term}), do: dynamic(remove_optional(other))
437439

438440
def intersection(left, right) do
439-
left = unfold(left)
440-
right = unfold(right)
441441
is_gradual_left = gradual?(left)
442442
is_gradual_right = gradual?(right)
443443

444444
cond do
445445
is_gradual_left and not is_gradual_right ->
446-
right_with_dynamic = Map.put(right, :dynamic, right)
446+
right_with_dynamic = Map.put(unfold(right), :dynamic, right)
447447
intersection_static(left, right_with_dynamic)
448448

449449
is_gradual_right and not is_gradual_left ->
450-
left_with_dynamic = Map.put(left, :dynamic, left)
450+
left_with_dynamic = Map.put(unfold(left), :dynamic, left)
451451
intersection_static(left_with_dynamic, right)
452452

453453
true ->
@@ -480,12 +480,9 @@ defmodule Module.Types.Descr do
480480
def difference(left, :term), do: keep_optional(left)
481481

482482
def difference(left, right) do
483-
left = unfold(left)
484-
right = unfold(right)
485-
486483
if gradual?(left) or gradual?(right) do
487-
{left_dynamic, left_static} = Map.pop(left, :dynamic, left)
488-
{right_dynamic, right_static} = Map.pop(right, :dynamic, right)
484+
{left_dynamic, left_static} = pop_dynamic(left)
485+
{right_dynamic, right_static} = pop_dynamic(right)
489486
dynamic_part = difference_static(left_dynamic, right_static)
490487

491488
Map.put(difference_static(left_static, right_dynamic), :dynamic, dynamic_part)
@@ -494,7 +491,8 @@ defmodule Module.Types.Descr do
494491
end
495492
end
496493

497-
# For static types, the difference is component-wise.
494+
# For static types, the difference is component-wise
495+
defp difference_static(left, descr) when descr == %{}, do: left
498496
defp difference_static(left, :term), do: keep_optional(left)
499497

500498
defp difference_static(left, right) do
@@ -563,7 +561,7 @@ defmodule Module.Types.Descr do
563561
Compute the negation of a type.
564562
"""
565563
def negation(:term), do: none()
566-
def negation(%{} = descr), do: difference(unfolded_term(), descr)
564+
def negation(%{} = descr), do: difference(term(), descr)
567565

568566
@doc """
569567
Check if a type is empty.
@@ -604,6 +602,42 @@ defmodule Module.Types.Descr do
604602
defp empty_key?(:tuple, value), do: tuple_empty?(value)
605603
defp empty_key?(_, _value), do: false
606604

605+
@doc """
606+
Converts all floats or integers into numbers.
607+
"""
608+
def numberize(:term), do: :term
609+
def numberize(descr), do: numberize_each(descr, [:bitmap, :tuple, :map, :list, :dynamic])
610+
611+
defp numberize_each(descr, [key | keys]) do
612+
case descr do
613+
%{^key => val} -> %{descr | key => numberize(key, val)}
614+
%{} -> descr
615+
end
616+
|> numberize_each(keys)
617+
end
618+
619+
defp numberize_each(descr, []) do
620+
descr
621+
end
622+
623+
defp numberize(:dynamic, descr), do: numberize(descr)
624+
defp numberize(:bitmap, bitmap) when (bitmap &&& @bit_number) != 0, do: bitmap ||| @bit_number
625+
defp numberize(:bitmap, bitmap), do: bitmap
626+
627+
defp numberize(:map, bdd) do
628+
bdd_map(bdd, fn {tag, fields} ->
629+
{tag, fields |> Map.to_list() |> Map.new(fn {key, value} -> {key, numberize(value)} end)}
630+
end)
631+
end
632+
633+
defp numberize(:tuple, bdd) do
634+
bdd_map(bdd, fn {tag, fields} -> {tag, Enum.map(fields, &numberize/1)} end)
635+
end
636+
637+
defp numberize(:list, bdd) do
638+
bdd_map(bdd, fn {head, tail} -> {numberize(head), numberize(tail)} end)
639+
end
640+
607641
@doc """
608642
Returns if the type is a singleton.
609643
"""
@@ -834,11 +868,7 @@ defmodule Module.Types.Descr do
834868
Incompatible subtypes include `integer() or list()`, `dynamic() and atom()`.
835869
"""
836870
def compatible?(left, right) do
837-
{left_dynamic, left_static} =
838-
case left do
839-
:term -> {:term, :term}
840-
_ -> Map.pop(left, :dynamic, left)
841-
end
871+
{left_dynamic, left_static} = pop_dynamic(left)
842872

843873
right_dynamic =
844874
case right do
@@ -862,11 +892,7 @@ defmodule Module.Types.Descr do
862892
as we traverse the program.
863893
"""
864894
def compatible_intersection(left, right) do
865-
{left_dynamic, left_static} =
866-
case left do
867-
:term -> {:term, :term}
868-
_ -> Map.pop(left, :dynamic, left)
869-
end
895+
{left_dynamic, left_static} = pop_dynamic(left)
870896

871897
right_dynamic =
872898
case right do
@@ -1068,7 +1094,7 @@ defmodule Module.Types.Descr do
10681094
def atom_fetch(:term), do: :error
10691095

10701096
def atom_fetch(%{} = descr) do
1071-
{static_or_dynamic, static} = Map.pop(descr, :dynamic, descr)
1097+
{static_or_dynamic, static} = pop_dynamic(descr)
10721098

10731099
if atom_only?(static) do
10741100
case static_or_dynamic do

lib/elixir/lib/module/types/expr.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ defmodule Module.Types.Expr do
325325
{acc, context} =
326326
of_clauses_fun(clauses, domain, @pending, nil, :fn, stack, context, [], fn
327327
trees, body, context, acc ->
328-
args = Pattern.of_domain(trees, context)
328+
args = Pattern.of_domain(trees, stack, context)
329329
add_inferred(acc, args, body)
330330
end)
331331

0 commit comments

Comments
 (0)