Skip to content

Commit 4a3f61f

Browse files
committed
Implement type checking across multiple clauses
1 parent 809971a commit 4a3f61f

File tree

3 files changed

+90
-55
lines changed

3 files changed

+90
-55
lines changed

lib/elixir/lib/module/types/descr.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ defmodule Module.Types.Descr do
159159
def gradual?(:term), do: false
160160
def gradual?(descr), do: is_map_key(descr, :dynamic)
161161

162+
def only_gradual?(%{dynamic: _} = descr), do: map_size(descr) == 1
163+
def only_gradual?(_), do: false
164+
162165
@doc """
163166
Make a whole type dynamic.
164167

lib/elixir/lib/module/types/of.ex

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ defmodule Module.Types.Of do
408408
{:erlang, :map_size, [{[open_map()], integer()}]},
409409
{:erlang, :node, [{[], atom()}]},
410410
{:erlang, :node, [{[pid() |> union(reference()) |> union(port())], atom()}]},
411+
{:erlang, :not, [{[atom([false])], atom([true])}, {[atom([true])], atom([false])}]},
411412
{:erlang, :rem, [{[integer(), integer()], integer()}]},
412413
{:erlang, :round, [{[union(integer(), float())], integer()}]},
413414
{:erlang, :self, [{[], pid()}]},
@@ -438,11 +439,25 @@ defmodule Module.Types.Of do
438439
[arity] = Enum.map(clauses, fn {args, _return} -> length(args) end) |> Enum.uniq()
439440
true = Code.ensure_loaded?(mod) and function_exported?(mod, fun, arity)
440441

442+
domain_clauses =
443+
case clauses do
444+
[_] ->
445+
{:strong, nil, clauses}
446+
447+
_ ->
448+
domain =
449+
clauses
450+
|> Enum.map(fn {args, _} -> args end)
451+
|> Enum.zip_with(fn types -> Enum.reduce(types, &union/2) end)
452+
453+
{:strong, domain, clauses}
454+
end
455+
441456
defp remote(unquote(mod), unquote(fun), unquote(arity)),
442-
do: unquote(Macro.escape(clauses))
457+
do: unquote(Macro.escape(domain_clauses))
443458
end
444459

445-
defp remote(_mod, _fun, _arity), do: []
460+
defp remote(_mod, _fun, _arity), do: :none
446461

447462
@doc """
448463
Checks a module is a valid remote.
@@ -457,23 +472,22 @@ defmodule Module.Types.Of do
457472
and they can only be converted into arrows by computing the union
458473
of all arguments.
459474
460-
* `{:strong, clauses}` - clauses from signatures. So far these are
461-
strong arrows with non-overlapping domains. If you find one matching
462-
clause, you can stop looking for others.
475+
* `{:strong, domain or nil, clauses}` - clauses from signatures. So far
476+
these are strong arrows with non-overlapping domains
463477
464478
"""
465479
def remote(module, fun, arity, meta, stack, context) when is_atom(module) do
466480
if Keyword.get(meta, :runtime_module, false) do
467481
{:none, context}
468482
else
469483
case remote(module, fun, arity) do
470-
[] -> {:none, check_export(module, fun, arity, meta, stack, context)}
471-
clauses -> {{:strong, clauses}, context}
484+
:none -> {:none, check_export(module, fun, arity, meta, stack, context)}
485+
clauses -> {clauses, context}
472486
end
473487
end
474488
end
475489

476-
# TODO: {:erlang, :not, [{[atom([false])], atom([true])}, {[atom([true])], atom([false])}]},
490+
# TODO: Fix ordering of tuple operations
477491

478492
def apply(:erlang, :element, [_, tuple], {_, meta, [index, _]} = expr, stack, context)
479493
when is_integer(index) do
@@ -482,7 +496,7 @@ defmodule Module.Types.Of do
482496
{value_type, context}
483497

484498
:badtuple ->
485-
{error_type(), badapply_error(expr, [integer(), tuple], stack, context)}
499+
{error_type(), to_badapply_error(expr, [integer(), tuple], stack, context)}
486500

487501
reason ->
488502
{error_type(), error({reason, expr, tuple, index - 1, context}, meta, stack, context)}
@@ -503,7 +517,7 @@ defmodule Module.Types.Of do
503517
{value_type, context}
504518

505519
:badtuple ->
506-
{error_type(), badapply_error(expr, [integer(), tuple, value], stack, context)}
520+
{error_type(), to_badapply_error(expr, [integer(), tuple, value], stack, context)}
507521

508522
reason ->
509523
{error_type(), error({reason, expr, tuple, index - 2, context}, meta, stack, context)}
@@ -517,7 +531,7 @@ defmodule Module.Types.Of do
517531
{value_type, context}
518532

519533
:badtuple ->
520-
{error_type(), badapply_error(expr, [integer(), tuple], stack, context)}
534+
{error_type(), to_badapply_error(expr, [integer(), tuple], stack, context)}
521535

522536
reason ->
523537
{error_type(), error({reason, expr, tuple, index - 1, context}, meta, stack, context)}
@@ -535,7 +549,7 @@ defmodule Module.Types.Of do
535549
{value_type, context}
536550

537551
:badnonemptylist ->
538-
{error_type(), badapply_error(expr, [list], stack, context)}
552+
{error_type(), to_badapply_error(expr, [list], stack, context)}
539553
end
540554
end
541555

@@ -545,7 +559,7 @@ defmodule Module.Types.Of do
545559
{value_type, context}
546560

547561
:badnonemptylist ->
548-
{error_type(), badapply_error(expr, [list], stack, context)}
562+
{error_type(), to_badapply_error(expr, [list], stack, context)}
549563
end
550564
end
551565

@@ -572,7 +586,7 @@ defmodule Module.Types.Of do
572586
if name in [:min, :max] do
573587
{union(left, right), context}
574588
else
575-
{comparison_return(boolean(), args_types, stack), context}
589+
{remote_return(boolean(), args_types, stack), context}
576590
end
577591
end
578592

@@ -591,7 +605,7 @@ defmodule Module.Types.Of do
591605
context
592606
end
593607

594-
{comparison_return(boolean(), args_types, stack), context}
608+
{remote_return(boolean(), args_types, stack), context}
595609
end
596610

597611
def apply(mod, name, args_types, expr, stack, context) do
@@ -608,14 +622,14 @@ defmodule Module.Types.Of do
608622
{:ok, type} ->
609623
{type, context}
610624

611-
{:error, clauses} ->
612-
error = {:badapply, expr, args_types, clauses, context}
625+
{:error, domain, clauses} ->
626+
error = {:badapply, expr, args_types, domain, clauses, context}
613627
{error_type(), error(error, elem(expr, 1), stack, context)}
614628
end
615629
end
616630
end
617631

618-
defp comparison_return(type, args_types, stack) do
632+
defp remote_return(type, args_types, stack) do
619633
cond do
620634
stack.mode == :static -> type
621635
Enum.any?(args_types, &gradual?/1) -> dynamic(type)
@@ -627,37 +641,45 @@ defmodule Module.Types.Of do
627641
{:ok, dynamic()}
628642
end
629643

630-
defp apply_remote({:strong, clauses}, args_types, stack) do
631-
if Enum.any?(args_types) do
632-
returns =
633-
for({expected, return} <- clauses, zip_compatible?(args_types, expected), do: return)
644+
defp apply_remote({:strong, nil, [{expected, return}] = clauses}, args_types, stack) do
645+
# Optimize single clauses as the domain is the single clause args.
646+
case zip_compatible?(args_types, expected) do
647+
true -> {:ok, remote_return(return, args_types, stack)}
648+
false -> {:error, expected, clauses}
649+
end
650+
end
634651

635-
cond do
636-
returns == [] -> {:error, clauses}
637-
stack.mode == :static -> {:ok, Enum.reduce(returns, &union/2)}
638-
true -> {:ok, dynamic(Enum.reduce(returns, &union/2))}
639-
end
652+
defp apply_remote({:strong, domain, clauses}, args_types, stack) do
653+
# If the type is only gradual, the compatibility check is the same
654+
# as a non disjoint check. So we skip checking compatibility twice.
655+
with true <- zip_compatible_or_only_gradual?(args_types, domain),
656+
[_ | _] = returns <-
657+
for({expected, return} <- clauses, zip_not_disjoint?(args_types, expected), do: return) do
658+
{:ok, returns |> Enum.reduce(&union/2) |> remote_return(args_types, stack)}
640659
else
641-
Enum.find_value(clauses, {:error, clauses}, fn {expected, return} ->
642-
if zip_subtype?(args_types, expected) do
643-
{:ok, return}
644-
end
645-
end)
660+
_ -> {:error, domain, clauses}
646661
end
647662
end
648663

649-
defp zip_subtype?([actual | actuals], [expected | expecteds]) do
650-
subtype?(actual, expected) and zip_subtype?(actuals, expecteds)
664+
defp zip_compatible_or_only_gradual?([actual | actuals], [expected | expecteds]) do
665+
(only_gradual?(actual) or compatible?(actual, expected)) and
666+
zip_compatible_or_only_gradual?(actuals, expecteds)
651667
end
652668

653-
defp zip_subtype?([], []), do: true
669+
defp zip_compatible_or_only_gradual?([], []), do: true
654670

655671
defp zip_compatible?([actual | actuals], [expected | expecteds]) do
656672
compatible?(actual, expected) and zip_compatible?(actuals, expecteds)
657673
end
658674

659675
defp zip_compatible?([], []), do: true
660676

677+
defp zip_not_disjoint?([actual | actuals], [expected | expecteds]) do
678+
not disjoint?(actual, expected) and zip_not_disjoint?(actuals, expecteds)
679+
end
680+
681+
defp zip_not_disjoint?([], []), do: true
682+
661683
defp check_export(module, fun, arity, meta, stack, context) do
662684
case ParallelChecker.fetch_export(stack.cache, module, fun, arity) do
663685
{:ok, mode, :def, reason} ->
@@ -905,7 +927,7 @@ defmodule Module.Types.Of do
905927
}
906928
end
907929

908-
def format_diagnostic({:badapply, expr, args_types, clauses, context}) do
930+
def format_diagnostic({:badapply, expr, args_types, domain, clauses, context}) do
909931
traces = collect_traces(expr, context)
910932
{{:., _, [mod, fun]}, _, args} = expr
911933

@@ -920,10 +942,10 @@ defmodule Module.Types.Of do
920942
921943
given types:
922944
923-
#{args_to_quoted_string(mod, fun, args_types) |> indent(4)}
945+
#{args_to_quoted_string(mod, fun, args_types, domain) |> indent(4)}
924946
925947
but expected one of:
926-
#{clauses_args_to_quoted_string(mod, fun, clauses, args_types)}
948+
#{clauses_args_to_quoted_string(mod, fun, clauses)}
927949
""",
928950
format_traces(traces)
929951
])
@@ -1050,9 +1072,9 @@ defmodule Module.Types.Of do
10501072
match?({{:., _, [var, _fun]}, _, _args} when is_var(var), expr)
10511073
end
10521074

1053-
defp badapply_error({{:., _, [mod, fun]}, meta, _} = expr, args_types, stack, context) do
1054-
clauses = remote(mod, fun, length(args_types))
1055-
error({:badapply, expr, args_types, clauses, context}, meta, stack, context)
1075+
defp to_badapply_error({{:., _, [mod, fun]}, meta, _} = expr, args_types, stack, context) do
1076+
{_type, domain, [{args, _} | _] = clauses} = remote(mod, fun, length(args_types))
1077+
error({:badapply, expr, args_types, domain || args, clauses, context}, meta, stack, context)
10561078
end
10571079

10581080
defp empty_if(condition, content) do
@@ -1072,24 +1094,34 @@ defmodule Module.Types.Of do
10721094

10731095
alias Inspect.Algebra, as: IA
10741096

1075-
defp clauses_args_to_quoted_string(mod, fun, [{args, _return}], args_types) do
1076-
"\n " <> (clause_args_to_quoted_string(mod, fun, args, args_types) |> indent(4))
1097+
defp clauses_args_to_quoted_string(mod, fun, [{args, _return}]) do
1098+
"\n " <> (clause_args_to_quoted_string(mod, fun, args) |> indent(4))
10771099
end
10781100

1079-
defp clauses_args_to_quoted_string(mod, fun, clauses, args_types) do
1101+
defp clauses_args_to_quoted_string(mod, fun, clauses) do
10801102
clauses
10811103
|> Enum.with_index(fn {args, _return}, index ->
1082-
"\n##{index + 1}\n#{clause_args_to_quoted_string(mod, fun, args, args_types)}" |> indent(4)
1104+
"""
1105+
1106+
##{index + 1}
1107+
#{clause_args_to_quoted_string(mod, fun, args)}\
1108+
"""
1109+
|> indent(4)
10831110
end)
10841111
|> Enum.join("\n")
10851112
end
10861113

1087-
defp clause_args_to_quoted_string(mod, fun, args, args_types) do
1114+
defp clause_args_to_quoted_string(mod, fun, args) do
1115+
docs = Enum.map(args, &(&1 |> to_quoted() |> Code.Formatter.to_algebra()))
1116+
args_docs_to_quoted_string(mod, fun, docs)
1117+
end
1118+
1119+
defp args_to_quoted_string(mod, fun, args_types, domain) do
10881120
ansi? = IO.ANSI.enabled?()
10891121

10901122
docs =
1091-
Enum.zip_with(args_types, args, fn actual, expected ->
1092-
doc = expected |> to_quoted() |> Code.Formatter.to_algebra()
1123+
Enum.zip_with(args_types, domain, fn actual, expected ->
1124+
doc = actual |> to_quoted() |> Code.Formatter.to_algebra()
10931125

10941126
cond do
10951127
compatible?(actual, expected) -> doc
@@ -1101,11 +1133,6 @@ defmodule Module.Types.Of do
11011133
args_docs_to_quoted_string(mod, fun, docs)
11021134
end
11031135

1104-
defp args_to_quoted_string(mod, fun, args) do
1105-
docs = Enum.map(args, &(&1 |> to_quoted() |> Code.Formatter.to_algebra()))
1106-
args_docs_to_quoted_string(mod, fun, docs)
1107-
end
1108-
11091136
defp args_docs_to_quoted_string(mod, fun, docs) do
11101137
{_mod, _fun, docs} = :elixir_rewrite.erl_to_ex(mod, fun, docs)
11111138
doc = IA.fold(docs, fn doc, acc -> IA.glue(IA.concat(doc, ","), acc) end)

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,11 @@ defmodule Module.Types.ExprTest do
673673
end
674674

675675
describe ":erlang rewrites" do
676+
test "Kernel.not/1" do
677+
assert typecheck!([x], not is_list(x)) == boolean()
678+
assert typedyn!([x], not is_list(x)) == dynamic(boolean())
679+
end
680+
676681
test "Kernel.+/2" do
677682
assert typeerror!([x = :foo, y = 123], x + y) |> strip_ansi() ==
678683
~l"""
@@ -701,13 +706,13 @@ defmodule Module.Types.ExprTest do
701706
where "x" was given the type:
702707
703708
# type: dynamic(:foo)
704-
# from: types_test.ex:677
709+
# from: types_test.ex:LINE-1
705710
x = :foo
706711
707712
where "y" was given the type:
708713
709714
# type: integer()
710-
# from: types_test.ex:677
715+
# from: types_test.ex:LINE-1
711716
y = 123
712717
"""
713718
end

0 commit comments

Comments
 (0)