Skip to content

Commit c1c2cf8

Browse files
authored
Type checking of protocol dispatch (#14117)
1 parent c34627c commit c1c2cf8

File tree

24 files changed

+713
-273
lines changed

24 files changed

+713
-273
lines changed

lib/elixir/lib/enum.ex

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5001,3 +5001,69 @@ defimpl Enumerable, for: Function do
50015001
description: "only anonymous functions of arity 2 are enumerable"
50025002
end
50035003
end
5004+
5005+
defimpl Enumerable, for: Range do
5006+
def reduce(first..last//step, acc, fun) do
5007+
reduce(first, last, acc, fun, step)
5008+
end
5009+
5010+
# TODO: Remove me on v2.0
5011+
def reduce(%{__struct__: Range, first: first, last: last} = range, acc, fun) do
5012+
step = if first <= last, do: 1, else: -1
5013+
reduce(Map.put(range, :step, step), acc, fun)
5014+
end
5015+
5016+
defp reduce(_first, _last, {:halt, acc}, _fun, _step) do
5017+
{:halted, acc}
5018+
end
5019+
5020+
defp reduce(first, last, {:suspend, acc}, fun, step) do
5021+
{:suspended, acc, &reduce(first, last, &1, fun, step)}
5022+
end
5023+
5024+
defp reduce(first, last, {:cont, acc}, fun, step)
5025+
when step > 0 and first <= last
5026+
when step < 0 and first >= last do
5027+
reduce(first + step, last, fun.(first, acc), fun, step)
5028+
end
5029+
5030+
defp reduce(_, _, {:cont, acc}, _fun, _up) do
5031+
{:done, acc}
5032+
end
5033+
5034+
def member?(first..last//step, value) when is_integer(value) do
5035+
if step > 0 do
5036+
{:ok, first <= value and value <= last and rem(value - first, step) == 0}
5037+
else
5038+
{:ok, last <= value and value <= first and rem(value - first, step) == 0}
5039+
end
5040+
end
5041+
5042+
# TODO: Remove me on v2.0
5043+
def member?(%{__struct__: Range, first: first, last: last} = range, value)
5044+
when is_integer(value) do
5045+
step = if first <= last, do: 1, else: -1
5046+
member?(Map.put(range, :step, step), value)
5047+
end
5048+
5049+
def member?(_, _value) do
5050+
{:ok, false}
5051+
end
5052+
5053+
def count(range) do
5054+
{:ok, Range.size(range)}
5055+
end
5056+
5057+
def slice(first.._//step = range) do
5058+
{:ok, Range.size(range), &slice(first + &1 * step, step + &3 - 1, &2)}
5059+
end
5060+
5061+
# TODO: Remove me on v2.0
5062+
def slice(%{__struct__: Range, first: first, last: last} = range) do
5063+
step = if first <= last, do: 1, else: -1
5064+
slice(Map.put(range, :step, step))
5065+
end
5066+
5067+
defp slice(_current, _step, 0), do: []
5068+
defp slice(current, step, remaining), do: [current | slice(current + step, step, remaining - 1)]
5069+
end

lib/elixir/lib/exception.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,7 +2003,7 @@ defmodule Protocol.UndefinedError do
20032003
# Indent only lines with contents on them
20042004
|> String.replace(~r/^(?=.+)/m, " ")
20052005

2006-
"protocol #{inspect(protocol)} not implemented for type " <>
2006+
"protocol #{inspect(protocol)} not implemented for " <>
20072007
value_type(value) <>
20082008
maybe_description(description) <>
20092009
maybe_available(protocol) <>
@@ -2038,7 +2038,7 @@ defmodule Protocol.UndefinedError do
20382038
". There are no implementations for this protocol."
20392039

20402040
{:consolidated, types} ->
2041-
". This protocol is implemented for the following type(s): " <>
2041+
". This protocol is implemented for: " <>
20422042
Enum.map_join(types, ", ", &inspect/1)
20432043

20442044
:not_consolidated ->

lib/elixir/lib/inspect.ex

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,13 @@ defprotocol Inspect do
200200
do: var!(info)
201201

202202
var!(name) = Macro.inspect_atom(:literal, unquote(module))
203-
unquote(inspect_module).inspect(var!(struct), var!(name), var!(infos), var!(opts))
203+
204+
unquote(inspect_module).inspect_as_struct(
205+
var!(struct),
206+
var!(name),
207+
var!(infos),
208+
var!(opts)
209+
)
204210
end
205211
end
206212
end
@@ -390,6 +396,10 @@ end
390396

391397
defimpl Inspect, for: Map do
392398
def inspect(map, opts) do
399+
inspect_as_map(map, opts)
400+
end
401+
402+
def inspect_as_map(map, opts) do
393403
list =
394404
if Keyword.get(opts.custom_options, :sort_maps) do
395405
map |> Map.to_list() |> :lists.sort()
@@ -408,7 +418,7 @@ defimpl Inspect, for: Map do
408418
map_container_doc(list, "", opts, fun)
409419
end
410420

411-
def inspect(map, name, infos, opts) do
421+
def inspect_as_struct(map, name, infos, opts) do
412422
fun = fn %{field: field}, opts -> Inspect.List.keyword({field, Map.get(map, field)}, opts) end
413423
map_container_doc(infos, name, opts, fun)
414424
end
@@ -599,25 +609,36 @@ end
599609
defimpl Inspect, for: Any do
600610
def inspect(%module{} = struct, opts) do
601611
try do
602-
{module.__struct__(), module.__info__(:struct)}
612+
module.__info__(:struct)
603613
rescue
604-
_ -> Inspect.Map.inspect(struct, opts)
614+
_ -> Inspect.Map.inspect_as_map(struct, opts)
605615
else
606-
{dunder, fields} ->
607-
if Map.keys(dunder) == Map.keys(struct) do
608-
infos =
609-
for %{field: field} = info <- fields,
610-
field not in [:__struct__, :__exception__],
611-
do: info
612-
613-
Inspect.Map.inspect(struct, Macro.inspect_atom(:literal, module), infos, opts)
616+
info ->
617+
if valid_struct?(info, struct) do
618+
info =
619+
for %{field: field} = map <- info,
620+
field != :__exception__,
621+
do: map
622+
623+
Inspect.Map.inspect_as_struct(struct, Macro.inspect_atom(:literal, module), info, opts)
614624
else
615-
Inspect.Map.inspect(struct, opts)
625+
Inspect.Map.inspect_as_map(struct, opts)
616626
end
617627
end
618628
end
619629

620-
def inspect(map, name, infos, opts) do
630+
defp valid_struct?(info, struct), do: valid_struct?(info, struct, map_size(struct) - 1)
631+
632+
defp valid_struct?([%{field: field} | info], struct, count) when is_map_key(struct, field),
633+
do: valid_struct?(info, struct, count - 1)
634+
635+
defp valid_struct?([], _struct, 0),
636+
do: true
637+
638+
defp valid_struct?(_fields, _struct, _count),
639+
do: false
640+
641+
def inspect_as_struct(map, name, infos, opts) do
621642
open = color_doc("#" <> name <> "<", :map, opts)
622643
sep = color_doc(",", :map, opts)
623644
close = color_doc(">", :map, opts)
@@ -631,6 +652,25 @@ defimpl Inspect, for: Any do
631652
end
632653
end
633654

655+
defimpl Inspect, for: Range do
656+
import Inspect.Algebra
657+
import Kernel, except: [inspect: 2]
658+
659+
def inspect(first..last//1, opts) when last >= first do
660+
concat([to_doc(first, opts), "..", to_doc(last, opts)])
661+
end
662+
663+
def inspect(first..last//step, opts) do
664+
concat([to_doc(first, opts), "..", to_doc(last, opts), "//", to_doc(step, opts)])
665+
end
666+
667+
# TODO: Remove me on v2.0
668+
def inspect(%{__struct__: Range, first: first, last: last} = range, opts) do
669+
step = if first <= last, do: 1, else: -1
670+
inspect(Map.put(range, :step, step), opts)
671+
end
672+
end
673+
634674
require Protocol
635675

636676
Protocol.derive(

lib/elixir/lib/inspect/algebra.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,14 +355,14 @@ defmodule Inspect.Algebra do
355355
# we won't try to render any failed instruct when building
356356
# the error message.
357357
if Process.get(:inspect_trap) do
358-
Inspect.Map.inspect(struct, opts)
358+
Inspect.Map.inspect_as_map(struct, opts)
359359
else
360360
try do
361361
Process.put(:inspect_trap, true)
362362

363363
inspected_struct =
364364
struct
365-
|> Inspect.Map.inspect(%{
365+
|> Inspect.Map.inspect_as_map(%{
366366
opts
367367
| syntax_colors: [],
368368
inspect_fun: Inspect.Opts.default_inspect_fun()
@@ -389,7 +389,7 @@ defmodule Inspect.Algebra do
389389
end
390390
end
391391
else
392-
Inspect.Map.inspect(struct, opts)
392+
Inspect.Map.inspect_as_map(struct, opts)
393393
end
394394
end
395395

lib/elixir/lib/kernel.ex

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3409,7 +3409,7 @@ defmodule Kernel do
34093409
34103410
"""
34113411
defmacro to_charlist(term) do
3412-
quote(do: List.Chars.to_charlist(unquote(term)))
3412+
quote(do: :"Elixir.List.Chars".to_charlist(unquote(term)))
34133413
end
34143414

34153415
@doc """
@@ -4064,7 +4064,7 @@ defmodule Kernel do
40644064
-1
40654065
end
40664066

4067-
{:%{}, [], [__struct__: Elixir.Range, first: first, last: last, step: step]}
4067+
{:%, [], [Elixir.Range, {:%{}, [], [first: first, last: last, step: step]}]}
40684068
end
40694069

40704070
defp stepless_range(nil, first, last, _caller) do
@@ -4090,7 +4090,7 @@ defmodule Kernel do
40904090
Macro.Env.stacktrace(caller)
40914091
)
40924092

4093-
{:%{}, [], [__struct__: Elixir.Range, first: first, last: last, step: step]}
4093+
{:%, [], [Elixir.Range, {:%{}, [], [first: first, last: last, step: step]}]}
40944094
end
40954095

40964096
defp stepless_range(:match, first, last, caller) do
@@ -4103,7 +4103,7 @@ defmodule Kernel do
41034103
Macro.Env.stacktrace(caller)
41044104
)
41054105

4106-
{:%{}, [], [__struct__: Elixir.Range, first: first, last: last]}
4106+
{:%, [], [Elixir.Range, {:%{}, [], [first: first, last: last]}]}
41074107
end
41084108

41094109
@doc """
@@ -4142,14 +4142,14 @@ defmodule Kernel do
41424142
range(__CALLER__.context, first, last, step)
41434143

41444144
false ->
4145-
range(__CALLER__.context, first, last, step)
4145+
{:%{}, [], [__struct__: Elixir.Range, first: first, last: last, step: step]}
41464146
end
41474147
end
41484148

41494149
defp range(context, first, last, step)
41504150
when is_integer(first) and is_integer(last) and is_integer(step)
41514151
when context != nil do
4152-
{:%{}, [], [__struct__: Elixir.Range, first: first, last: last, step: step]}
4152+
{:%, [], [Elixir.Range, {:%{}, [], [first: first, last: last, step: step]}]}
41534153
end
41544154

41554155
defp range(nil, first, last, step) do
@@ -4559,11 +4559,10 @@ defmodule Kernel do
45594559
raise ArgumentError, "found unescaped value on the right side of in/2: #{inspect(right)}"
45604560

45614561
right ->
4562-
with {:%{}, _meta, fields} <- right,
4563-
[__struct__: Elixir.Range, first: first, last: last, step: step] <-
4564-
:lists.usort(fields) do
4565-
in_var(in_body?, left, &in_range(&1, expand.(first), expand.(last), expand.(step)))
4566-
else
4562+
case range_fields(right) do
4563+
[first: first, last: last, step: step] ->
4564+
in_var(in_body?, left, &in_range(&1, expand.(first), expand.(last), expand.(step)))
4565+
45674566
_ when in_body? ->
45684567
quote(do: Elixir.Enum.member?(unquote(right), unquote(left)))
45694568

@@ -4573,6 +4572,10 @@ defmodule Kernel do
45734572
end
45744573
end
45754574

4575+
defp range_fields({:%, _, [Elixir.Range, {:%{}, _, fields}]}), do: :lists.usort(fields)
4576+
defp range_fields({:%{}, _, [__struct__: Elixir.Range] ++ fields}), do: :lists.usort(fields)
4577+
defp range_fields(_), do: []
4578+
45764579
defp raise_on_invalid_args_in_2(right) do
45774580
raise ArgumentError, <<
45784581
"invalid right argument for operator \"in\", it expects a compile-time proper list ",
@@ -5385,7 +5388,7 @@ defmodule Kernel do
53855388
53865389
john = %User{name: "John"}
53875390
MyProtocol.call(john)
5388-
** (Protocol.UndefinedError) protocol MyProtocol not implemented for %User{...}
5391+
** (Protocol.UndefinedError) protocol MyProtocol not implemented for User (a struct)
53895392
53905393
`defstruct/1`, however, allows protocol implementations to be
53915394
*derived*. This can be done by defining a `@derive` attribute as a

lib/elixir/lib/module/types.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ defmodule Module.Types do
145145
defp default_domain({_, arity} = fun_arity, impl) do
146146
with {for, callbacks} <- impl,
147147
true <- fun_arity in callbacks do
148-
[Module.Types.Of.impl(for) | List.duplicate(Descr.dynamic(), arity - 1)]
148+
[Descr.dynamic(Module.Types.Of.impl(for)) | List.duplicate(Descr.dynamic(), arity - 1)]
149149
else
150150
_ -> List.duplicate(Descr.dynamic(), arity)
151151
end
@@ -282,7 +282,7 @@ defmodule Module.Types do
282282

283283
try do
284284
{args_types, context} =
285-
Pattern.of_head(args, guards, expected, :default, meta, stack, context)
285+
Pattern.of_head(args, guards, expected, {:infer, expected}, meta, stack, context)
286286

287287
{return_type, context} =
288288
Expr.of_expr(body, stack, context)

0 commit comments

Comments
 (0)