Skip to content

Commit 503c28d

Browse files
committed
Optimize disjoint to abort as soon as possible
1 parent b068710 commit 503c28d

File tree

4 files changed

+121
-33
lines changed

4 files changed

+121
-33
lines changed

lib/elixir/lib/module/types/descr.ex

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ defmodule Module.Types.Descr do
123123
## Set operations
124124

125125
def term_type?(:term), do: true
126-
def term_type?(descr), do: subtype_static(unfolded_term(), Map.delete(descr, :dynamic))
126+
def term_type?(descr), do: subtype_static?(unfolded_term(), Map.delete(descr, :dynamic))
127127

128128
def dynamic_term_type?(descr), do: descr == %{dynamic: :term}
129129

@@ -241,9 +241,28 @@ defmodule Module.Types.Descr do
241241
defp difference_static(left, :term) when not is_optional_static(left), do: none()
242242

243243
defp difference_static(left, right) do
244-
iterator_difference(:maps.next(:maps.iterator(unfold(right))), unfold(left))
244+
iterator_difference_static(:maps.next(:maps.iterator(unfold(right))), unfold(left))
245245
end
246246

247+
defp iterator_difference_static({key, v2, iterator}, map) do
248+
acc =
249+
case map do
250+
%{^key => v1} ->
251+
case difference(key, v1, v2) do
252+
0 -> Map.delete(map, key)
253+
[] -> Map.delete(map, key)
254+
value -> %{map | key => value}
255+
end
256+
257+
%{} ->
258+
map
259+
end
260+
261+
iterator_difference_static(:maps.next(iterator), acc)
262+
end
263+
264+
defp iterator_difference_static(:none, map), do: map
265+
247266
# Returning 0 from the callback is taken as none() for that subtype.
248267
@compile {:inline, difference: 3}
249268
defp difference(:bitmap, v1, v2), do: bitmap_difference(v1, v2)
@@ -347,19 +366,19 @@ defmodule Module.Types.Descr do
347366
cond do
348367
is_grad_left and not is_grad_right ->
349368
left_dynamic = Map.get(left, :dynamic)
350-
subtype_static(left_dynamic, right)
369+
subtype_static?(left_dynamic, right)
351370

352371
is_grad_right and not is_grad_left ->
353372
right_static = Map.delete(right, :dynamic)
354-
subtype_static(left, right_static)
373+
subtype_static?(left, right_static)
355374

356375
true ->
357-
subtype_static(left, right)
376+
subtype_static?(left, right)
358377
end
359378
end
360379

361-
defp subtype_static(same, same), do: true
362-
defp subtype_static(left, right), do: empty?(difference_static(left, right))
380+
defp subtype_static?(same, same), do: true
381+
defp subtype_static?(left, right), do: empty?(difference_static(left, right))
363382

364383
@doc """
365384
Check if a type is equal to another.
@@ -372,8 +391,33 @@ defmodule Module.Types.Descr do
372391

373392
@doc """
374393
Check if two types are disjoint.
394+
395+
This reimplements intersection/2 but aborts as it finds a disjoint part.
375396
"""
376-
def disjoint?(left, right), do: empty?(intersection(left, right))
397+
def disjoint?(:term, other) when not is_optional(other), do: empty?(other)
398+
def disjoint?(other, :term) when not is_optional(other), do: empty?(other)
399+
def disjoint?(%{dynamic: :term}, other) when not is_optional(other), do: empty?(other)
400+
def disjoint?(other, %{dynamic: :term}) when not is_optional(other), do: empty?(other)
401+
402+
def disjoint?(left, right) do
403+
left = unfold(left)
404+
right = unfold(right)
405+
is_gradual_left = gradual?(left)
406+
is_gradual_right = gradual?(right)
407+
408+
cond do
409+
is_gradual_left and not is_gradual_right ->
410+
right_with_dynamic = Map.put(right, :dynamic, right)
411+
not non_disjoint_intersection?(left, right_with_dynamic)
412+
413+
is_gradual_right and not is_gradual_left ->
414+
left_with_dynamic = Map.put(left, :dynamic, left)
415+
not non_disjoint_intersection?(left_with_dynamic, right)
416+
417+
true ->
418+
not non_disjoint_intersection?(left, right)
419+
end
420+
end
377421

378422
@doc """
379423
Checks if a type is a compatible subtype of another.
@@ -400,9 +444,13 @@ defmodule Module.Types.Descr do
400444
right_dynamic = Map.get(right, :dynamic, right)
401445

402446
if empty?(left_static) do
403-
not disjoint?(left_dynamic, right_dynamic)
447+
cond do
448+
left_dynamic == :term -> not empty?(right_dynamic)
449+
right_dynamic == :term -> not empty?(left_dynamic)
450+
true -> non_disjoint_intersection?(left_dynamic, right_dynamic)
451+
end
404452
else
405-
subtype_static(left_static, right_dynamic)
453+
subtype_static?(left_static, right_dynamic)
406454
end
407455
end
408456

@@ -1610,18 +1658,23 @@ defmodule Module.Types.Descr do
16101658
m = length(elements2)
16111659

16121660
cond do
1613-
(tag1 == :closed and n < m) or (tag2 == :closed and n > m) -> throw(:empty)
1614-
tag1 == :open and tag2 == :open -> {:open, zip_intersection(elements1, elements2, [])}
1615-
true -> {:closed, zip_intersection(elements1, elements2, [])}
1661+
(tag1 == :closed and n < m) or (tag2 == :closed and n > m) ->
1662+
throw(:empty)
1663+
1664+
tag1 == :open and tag2 == :open ->
1665+
{:open, zip_non_empty_intersection!(elements1, elements2, [])}
1666+
1667+
true ->
1668+
{:closed, zip_non_empty_intersection!(elements1, elements2, [])}
16161669
end
16171670
end
16181671

16191672
# Intersects two lists of types, and _appends_ the extra elements to the result.
1620-
defp zip_intersection([], types2, acc), do: Enum.reverse(acc, types2)
1621-
defp zip_intersection(types1, [], acc), do: Enum.reverse(acc, types1)
1673+
defp zip_non_empty_intersection!([], types2, acc), do: Enum.reverse(acc, types2)
1674+
defp zip_non_empty_intersection!(types1, [], acc), do: Enum.reverse(acc, types1)
16221675

1623-
defp zip_intersection([type1 | rest1], [type2 | rest2], acc) do
1624-
zip_intersection(rest1, rest2, [non_empty_intersection!(type1, type2) | acc])
1676+
defp zip_non_empty_intersection!([type1 | rest1], [type2 | rest2], acc) do
1677+
zip_non_empty_intersection!(rest1, rest2, [non_empty_intersection!(type1, type2) | acc])
16251678
end
16261679

16271680
defp tuple_difference(dnf1, dnf2) do
@@ -2207,22 +2260,24 @@ defmodule Module.Types.Descr do
22072260

22082261
defp iterator_intersection(:none, _map, acc, _fun), do: :maps.from_list(acc)
22092262

2210-
defp iterator_difference({key, v2, iterator}, map) do
2211-
acc =
2212-
case map do
2213-
%{^key => v1} ->
2214-
case difference(key, v1, v2) do
2215-
0 -> Map.delete(map, key)
2216-
[] -> Map.delete(map, key)
2217-
value -> %{map | key => value}
2218-
end
2219-
2220-
%{} ->
2221-
map
2222-
end
2263+
defp non_disjoint_intersection?(left, right) do
2264+
# Erlang maps:intersect_with/3 has to preserve the order in combiner.
2265+
# We don't care about the order, so we have a faster implementation.
2266+
if map_size(left) > map_size(right) do
2267+
iterator_non_disjoint_intersection?(:maps.next(:maps.iterator(right)), left)
2268+
else
2269+
iterator_non_disjoint_intersection?(:maps.next(:maps.iterator(left)), right)
2270+
end
2271+
end
22232272

2224-
iterator_difference(:maps.next(iterator), acc)
2273+
defp iterator_non_disjoint_intersection?({key, v1, iterator}, map) do
2274+
with %{^key => v2} <- map,
2275+
value when value != 0 and value != @none <- intersection(key, v1, v2) do
2276+
true
2277+
else
2278+
_ -> iterator_non_disjoint_intersection?(:maps.next(iterator), map)
2279+
end
22252280
end
22262281

2227-
defp iterator_difference(:none, map), do: map
2282+
defp iterator_non_disjoint_intersection?(:none, _map), do: false
22282283
end

lib/elixir/lib/module/types/of.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,17 @@ defmodule Module.Types.Of do
323323
{:binary, :copy, [{[binary(), integer()], binary()}]},
324324
{:erlang, :atom_to_binary, [{[atom()], binary()}]},
325325
{:erlang, :atom_to_list, [{[atom()], list(integer())}]},
326+
{:erlang, :band, [{[integer(), integer()], integer()}]},
326327
{:erlang, :binary_to_atom, [{[binary()], atom()}]},
327328
{:erlang, :binary_to_existing_atom, [{[binary()], atom()}]},
328329
{:erlang, :binary_to_integer, [{[binary()], integer()}]},
329330
{:erlang, :binary_to_integer, [{[binary(), integer()], integer()}]},
330331
{:erlang, :binary_to_float, [{[binary()], float()}]},
332+
{:erlang, :bnot, [{[integer()], integer()}]},
333+
{:erlang, :bor, [{[integer(), integer()], integer()}]},
334+
{:erlang, :bsl, [{[integer(), integer()], integer()}]},
335+
{:erlang, :bsr, [{[integer(), integer()], integer()}]},
336+
{:erlang, :bxor, [{[integer(), integer()], integer()}]},
331337
{:erlang, :integer_to_binary, [{[integer()], binary()}]},
332338
{:erlang, :integer_to_binary, [{[integer(), integer()], binary()}]},
333339
{:erlang, :integer_to_list, [{[integer()], non_empty_list(integer())}]},

lib/elixir/test/elixir/module/types/descr_test.exs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,10 @@ defmodule Module.Types.DescrTest do
10031003
end
10041004

10051005
describe "disjoint" do
1006+
test "optional" do
1007+
assert disjoint?(term(), if_set(none()))
1008+
end
1009+
10061010
test "map" do
10071011
refute disjoint?(open_map(), open_map(a: integer()))
10081012
end

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ defmodule Module.Types.ExprTest do
666666
end
667667
end
668668

669-
describe "apply" do
669+
describe ":erlang rewrites" do
670670
test "Integer.to_string/1" do
671671
assert typeerror!([x = :foo], Integer.to_string(x)) ==
672672
~l"""
@@ -689,6 +689,29 @@ defmodule Module.Types.ExprTest do
689689
x = :foo
690690
"""
691691
end
692+
693+
test "Bitwise.bnot/1" do
694+
assert typeerror!([x = :foo], Bitwise.bnot(x)) ==
695+
~l"""
696+
incompatible types given to Bitwise.bnot/1:
697+
698+
Bitwise.bnot(x)
699+
700+
expected types:
701+
702+
integer()
703+
704+
but got types:
705+
706+
dynamic(:foo)
707+
708+
where "x" was given the type:
709+
710+
# type: dynamic(:foo)
711+
# from: types_test.ex:LINE-1
712+
x = :foo
713+
"""
714+
end
692715
end
693716

694717
describe "try" do

0 commit comments

Comments
 (0)