Skip to content

Commit 860f485

Browse files
authored
Inference of patterns (#13909)
This pull request adds inference of patterns, being able to refine types as new information is added. The next step is to use this inference in the compiler itself. Future work will add module-local inference of return types (most likely still in v1.18) and inference of guards (most likely in v1.19). In order to support inference of patterns, Elixir will raise if it finds recursive variable definitions. This means patterns that never match, such as this one, will no longer compile: def foo(x = {:ok, y}, x = y) However, recursion of root variables (where variables directly point to each other), will also fail to compile: def foo(x = y, y = z, z = x) While the definition above could succeed (as long as all three arguments are equal), there is a much cleaner version of writing the same code, that does not require solving cycles in our head: def foo(x, x, x)
1 parent 673fe4a commit 860f485

21 files changed

+1378
-701
lines changed

lib/elixir/lib/module/types.ex

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ defmodule Module.Types do
8484
# A list of all warnings found so far
8585
warnings: [],
8686
# Information about all vars and their types
87-
vars: %{}
87+
vars: %{},
88+
# Information about variables and arguments from patterns
89+
pattern_info: nil
8890
}
8991
end
9092
end

lib/elixir/lib/module/types/descr.ex

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ defmodule Module.Types.Descr do
3838

3939
# Type definitions
4040

41+
defguard is_descr(descr) when is_map(descr) or descr == :term
42+
4143
def dynamic(), do: %{dynamic: :term}
4244
def none(), do: @none
4345
def term(), do: :term
@@ -55,8 +57,8 @@ defmodule Module.Types.Descr do
5557
def integer(), do: %{bitmap: @bit_integer}
5658
def float(), do: %{bitmap: @bit_float}
5759
def fun(), do: %{bitmap: @bit_fun}
58-
def list(), do: %{bitmap: @bit_list}
59-
def non_empty_list(), do: %{bitmap: @bit_non_empty_list}
60+
def list(_arg), do: %{bitmap: @bit_list}
61+
def non_empty_list(_arg, _tail \\ empty_list()), do: %{bitmap: @bit_non_empty_list}
6062
def open_map(), do: %{map: @map_top}
6163
def open_map(pairs), do: map_descr(:open, pairs)
6264
def open_tuple(elements), do: tuple_descr(:open, elements)
@@ -113,6 +115,8 @@ defmodule Module.Types.Descr do
113115
def term_type?(:term), do: true
114116
def term_type?(descr), do: subtype_static(unfolded_term(), Map.delete(descr, :dynamic))
115117

118+
def dynamic_term_type?(descr), do: descr == %{dynamic: :term}
119+
116120
def gradual?(:term), do: false
117121
def gradual?(descr), do: is_map_key(descr, :dynamic)
118122

@@ -133,6 +137,8 @@ defmodule Module.Types.Descr do
133137
"""
134138
def union(:term, other) when not is_optional(other), do: :term
135139
def union(other, :term) when not is_optional(other), do: :term
140+
def union(none, other) when none == %{}, do: other
141+
def union(other, none) when none == %{}, do: other
136142

137143
def union(left, right) do
138144
left = unfold(left)
@@ -166,6 +172,8 @@ defmodule Module.Types.Descr do
166172
"""
167173
def intersection(:term, other) when not is_optional(other), do: other
168174
def intersection(other, :term) when not is_optional(other), do: other
175+
def intersection(%{dynamic: :term}, other) when not is_optional(other), do: dynamic(other)
176+
def intersection(other, %{dynamic: :term}) when not is_optional(other), do: dynamic(other)
169177

170178
def intersection(left, right) do
171179
left = unfold(left)
@@ -385,6 +393,18 @@ defmodule Module.Types.Descr do
385393

386394
## Bitmaps
387395

396+
@doc """
397+
Optimized version of `not empty?(intersection(empty_list(), type))`.
398+
"""
399+
def empty_list_type?(:term), do: true
400+
def empty_list_type?(%{dynamic: :term}), do: true
401+
402+
def empty_list_type?(%{dynamic: %{bitmap: bitmap}}) when (bitmap &&& @bit_empty_list) != 0,
403+
do: true
404+
405+
def empty_list_type?(%{bitmap: bitmap}) when (bitmap &&& @bit_empty_list) != 0, do: true
406+
def empty_list_type?(_), do: false
407+
388408
@doc """
389409
Optimized version of `not empty?(intersection(binary(), type))`.
390410
"""

lib/elixir/lib/module/types/expr.ex

Lines changed: 44 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,69 +6,63 @@ defmodule Module.Types.Expr do
66

77
14 = length(Macro.Env.__info__(:struct))
88

9+
aliases = list(tuple([atom(), atom()]))
10+
functions_and_macros = list(tuple([atom(), list(tuple([atom(), integer()]))]))
11+
list_of_modules = list(atom())
12+
913
@caller closed_map(
1014
__struct__: atom([Macro.Env]),
11-
aliases: list(),
15+
aliases: aliases,
1216
context: atom([:match, :guard, nil]),
13-
context_modules: list(),
17+
context_modules: list_of_modules,
1418
file: binary(),
1519
function: union(tuple(), atom([nil])),
16-
functions: list(),
20+
functions: functions_and_macros,
1721
lexical_tracker: union(pid(), atom([nil])),
1822
line: integer(),
19-
macro_aliases: list(),
20-
macros: list(),
23+
macro_aliases: aliases,
24+
macros: functions_and_macros,
2125
module: atom(),
22-
requires: list(),
23-
tracers: list(),
26+
requires: list_of_modules,
27+
tracers: list_of_modules,
2428
versioned_vars: open_map()
2529
)
2630

2731
@atom_true atom([true])
2832
@exception open_map(__struct__: atom(), __exception__: @atom_true)
2933

30-
# of_expr/4 is public as it is called recursively from Of.binary
31-
def of_expr(expr, expected_expr, stack, context) do
32-
with {:ok, actual, context} <- of_expr(expr, stack, context) do
33-
Of.intersect(actual, expected_expr, stack, context)
34-
end
35-
end
36-
3734
# :atom
38-
def of_expr(atom, _stack, context) when is_atom(atom) do
39-
{:ok, atom([atom]), context}
40-
end
35+
def of_expr(atom, _stack, context) when is_atom(atom),
36+
do: {:ok, atom([atom]), context}
4137

4238
# 12
43-
def of_expr(literal, _stack, context) when is_integer(literal) do
44-
{:ok, integer(), context}
45-
end
39+
def of_expr(literal, _stack, context) when is_integer(literal),
40+
do: {:ok, integer(), context}
4641

4742
# 1.2
48-
def of_expr(literal, _stack, context) when is_float(literal) do
49-
{:ok, float(), context}
50-
end
43+
def of_expr(literal, _stack, context) when is_float(literal),
44+
do: {:ok, float(), context}
5145

5246
# "..."
53-
def of_expr(literal, _stack, context) when is_binary(literal) do
54-
{:ok, binary(), context}
55-
end
47+
def of_expr(literal, _stack, context) when is_binary(literal),
48+
do: {:ok, binary(), context}
5649

5750
# #PID<...>
58-
def of_expr(literal, _stack, context) when is_pid(literal) do
59-
{:ok, pid(), context}
60-
end
51+
def of_expr(literal, _stack, context) when is_pid(literal),
52+
do: {:ok, pid(), context}
6153

6254
# []
63-
def of_expr([], _stack, context) do
64-
{:ok, empty_list(), context}
65-
end
55+
def of_expr([], _stack, context),
56+
do: {:ok, empty_list(), context}
6657

67-
# TODO: [expr, ...]
68-
def of_expr(exprs, stack, context) when is_list(exprs) do
69-
case map_reduce_ok(exprs, context, &of_expr(&1, stack, &2)) do
70-
{:ok, _types, context} -> {:ok, non_empty_list(), context}
71-
{:error, context} -> {:error, context}
58+
# [expr, ...]
59+
def of_expr(list, stack, context) when is_list(list) do
60+
{prefix, suffix} = unpack_list(list, [])
61+
62+
with {:ok, prefix, context} <-
63+
map_reduce_ok(prefix, context, &of_expr(&1, stack, &2)),
64+
{:ok, suffix, context} <- of_expr(suffix, stack, context) do
65+
{:ok, non_empty_list(Enum.reduce(prefix, &union/2), suffix), context}
7266
end
7367
end
7468

@@ -84,27 +78,11 @@ defmodule Module.Types.Expr do
8478
def of_expr({:<<>>, _meta, args}, stack, context) do
8579
case Of.binary(args, :expr, stack, context) do
8680
{:ok, context} -> {:ok, binary(), context}
87-
# It is safe to discard errors from binary inside expressions
81+
# It is safe to discard errors from binaries, we can continue typechecking
8882
{:error, context} -> {:ok, binary(), context}
8983
end
9084
end
9185

92-
# TODO: left | []
93-
def of_expr({:|, _meta, [left_expr, []]}, stack, context) do
94-
of_expr(left_expr, stack, context)
95-
end
96-
97-
# TODO: left | right
98-
def of_expr({:|, _meta, [left_expr, right_expr]}, stack, context) do
99-
case of_expr(left_expr, stack, context) do
100-
{:ok, _left, context} ->
101-
of_expr(right_expr, stack, context)
102-
103-
{:error, context} ->
104-
{:error, context}
105-
end
106-
end
107-
10886
def of_expr({:__CALLER__, _meta, var_context}, _stack, context)
10987
when is_atom(var_context) do
11088
{:ok, @caller, context}
@@ -113,7 +91,7 @@ defmodule Module.Types.Expr do
11391
# TODO: __STACKTRACE__
11492
def of_expr({:__STACKTRACE__, _meta, var_context}, _stack, context)
11593
when is_atom(var_context) do
116-
{:ok, list(), context}
94+
{:ok, list(term()), context}
11795
end
11896

11997
# {...}
@@ -123,10 +101,10 @@ defmodule Module.Types.Expr do
123101
end
124102
end
125103

126-
# TODO: left = right
104+
# left = right
127105
def of_expr({:=, _meta, [left_expr, right_expr]} = expr, stack, context) do
128106
with {:ok, right_type, context} <- of_expr(right_expr, stack, context) do
129-
Pattern.of_match(left_expr, {right_type, expr}, stack, context)
107+
Pattern.of_match(left_expr, right_type, expr, stack, context)
130108
end
131109
end
132110

@@ -152,6 +130,7 @@ defmodule Module.Types.Expr do
152130
{:ok, {key, type}, context}
153131
end
154132
end),
133+
# TODO: args_types could be an empty list
155134
{:ok, struct_type, context} <-
156135
Of.struct(module, args_types, :only_defaults, struct_meta, stack, context),
157136
{:ok, map_type, context} <- of_expr(map, stack, context) do
@@ -172,6 +151,7 @@ defmodule Module.Types.Expr do
172151

173152
# %Struct{}
174153
def of_expr({:%, _, [module, {:%{}, _, args}]} = expr, stack, context) do
154+
# TODO: We should not skip defaults
175155
Of.struct(expr, module, args, :skip_defaults, stack, context, &of_expr/3)
176156
end
177157

@@ -359,7 +339,7 @@ defmodule Module.Types.Expr do
359339
{:ok, fun(), context}
360340
end
361341

362-
# TODO: call(arg)
342+
# TODO: local_call(arg)
363343
def of_expr({fun, _meta, args}, stack, context)
364344
when is_atom(fun) and is_list(args) do
365345
with {:ok, _arg_types, context} <-
@@ -404,7 +384,7 @@ defmodule Module.Types.Expr do
404384
end
405385

406386
{:ok, _type, context} =
407-
Of.refine_var(var, {expected, expr}, formatter, stack, context)
387+
Of.refine_var(var, expected, expr, formatter, stack, context)
408388

409389
context
410390
end
@@ -426,7 +406,7 @@ defmodule Module.Types.Expr do
426406

427407
defp for_clause({:<<>>, _, [{:<-, meta, [left, right]}]}, stack, context) do
428408
with {:ok, right_type, context} <- of_expr(right, stack, context),
429-
{:ok, _pattern_type, context} <- Pattern.of_match(left, {binary(), left}, stack, context) do
409+
{:ok, _pattern_type, context} <- Pattern.of_match(left, binary(), left, stack, context) do
430410
if binary_type?(right_type) do
431411
{:ok, context}
432412
else
@@ -539,7 +519,7 @@ defmodule Module.Types.Expr do
539519
## Warning formatting
540520

541521
def format_diagnostic({:badupdate, type, expr, expected_type, actual_type, context}) do
542-
traces = Of.collect_traces(expr, context)
522+
traces = collect_traces(expr, context)
543523

544524
%{
545525
details: %{typing_traces: traces},
@@ -558,13 +538,13 @@ defmodule Module.Types.Expr do
558538
559539
#{to_quoted_string(actual_type) |> indent(4)}
560540
""",
561-
Of.format_traces(traces)
541+
format_traces(traces)
562542
])
563543
}
564544
end
565545

566546
def format_diagnostic({:badbinary, type, expr, context}) do
567-
traces = Of.collect_traces(expr, context)
547+
traces = collect_traces(expr, context)
568548

569549
%{
570550
details: %{typing_traces: traces},
@@ -579,7 +559,7 @@ defmodule Module.Types.Expr do
579559
580560
#{to_quoted_string(type) |> indent(4)}
581561
""",
582-
Of.format_traces(traces)
562+
format_traces(traces)
583563
])
584564
}
585565
end

0 commit comments

Comments
 (0)