Skip to content

Commit 0840c0c

Browse files
N5N3KristofferC
authored andcommitted
subtype: save some union stack space for ∃ free cases. (#58159)
and avoid eager UnionAll unwrapping to hit more fast path. close #58129 (test passed locally) close #56350 (MWE returns `Tuple{Any, Any, Vararg}` now.) (cherry picked from commit 334c316)
1 parent 16832c1 commit 0840c0c

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/subtype.c

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J
139139

140140
static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT
141141
{
142-
assert(i >= 0 && i <= 32767); // limited by the depth bit.
142+
assert(i >= 0 && i < 32767); // limited by the depth bit.
143143
// get the `i`th bit in an array of 32-bit words
144144
jl_bits_stack_t *stack = &st->stack;
145145
while (i >= sizeof(stack->data) * 8) {
@@ -153,7 +153,7 @@ static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT
153153

154154
static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
155155
{
156-
assert(i >= 0 && i <= 32767); // limited by the depth bit.
156+
assert(i >= 0 && i < 32767); // limited by the depth bit.
157157
jl_bits_stack_t *stack = &st->stack;
158158
while (i >= sizeof(stack->data) * 8) {
159159
if (__unlikely(stack->next == NULL)) {
@@ -1437,11 +1437,14 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
14371437
}
14381438
if (jl_is_unionall(y)) {
14391439
jl_varbinding_t *xb = lookup(e, (jl_tvar_t*)x);
1440-
if (xb == NULL ? !e->ignore_free : !xb->right) {
1440+
jl_value_t *xub = xb == NULL ? ((jl_tvar_t *)x)->ub : xb->ub;
1441+
if ((xb == NULL ? !e->ignore_free : !xb->right) && xub != y) {
14411442
// We'd better unwrap `y::UnionAll` eagerly if `x` isa ∀-var.
14421443
// This makes sure the following cases work correct:
14431444
// 1) `∀T <: Union{∃S, SomeType{P}} where {P}`: `S == Any` ==> `S >: T`
14441445
// 2) `∀T <: Union{∀T, SomeType{P}} where {P}`:
1446+
// note: if xub == y we'd better try `subtype_var` as `subtype_left_var`
1447+
// hit `==` based fast path.
14451448
return subtype_unionall(x, (jl_unionall_t*)y, e, 1, param);
14461449
}
14471450
}
@@ -1579,6 +1582,8 @@ static int has_exists_typevar(jl_value_t *x, jl_stenv_t *e) JL_NOTSAFEPOINT
15791582
return env != NULL && jl_has_bound_typevars(x, env);
15801583
}
15811584

1585+
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
1586+
15821587
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
15831588
{
15841589
int16_t oldRmore = e->Runions.more;
@@ -1592,7 +1597,18 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t
15921597
return jl_subtype(x, y);
15931598
int has_exists = (!kindx && has_exists_typevar(x, e)) ||
15941599
(!kindy && has_exists_typevar(y, e));
1595-
if (has_exists && (is_exists_typevar(x, e) != is_exists_typevar(y, e))) {
1600+
if (!has_exists) {
1601+
// We can use ∀_∃_subtype safely for ∃ free inputs.
1602+
// This helps to save some bits in union stack.
1603+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1604+
e->Lunions.used = e->Runions.used = 0;
1605+
e->Lunions.depth = e->Runions.depth = 0;
1606+
e->Lunions.more = e->Runions.more = 0;
1607+
sub = forall_exists_subtype(x, y, e, param);
1608+
pop_unionstate(&e->Runions, &oldRunions);
1609+
return sub;
1610+
}
1611+
if (is_exists_typevar(x, e) != is_exists_typevar(y, e)) {
15961612
e->Lunions.used = 0;
15971613
while (1) {
15981614
e->Lunions.more = 0;
@@ -1606,7 +1622,7 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t
16061622
if (limit_slow == -1)
16071623
limit_slow = kindx || kindy;
16081624
jl_savedenv_t se;
1609-
save_env(e, &se, has_exists);
1625+
save_env(e, &se, 1);
16101626
int count, limited = 0, ini_count = 0;
16111627
jl_saved_unionstate_t latestLunions = {0, 0, 0, NULL};
16121628
while (1) {
@@ -1624,26 +1640,26 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t
16241640
limited = 1;
16251641
if (!sub || !next_union_state(e, 0))
16261642
break;
1627-
if (limited || !has_exists || e->Runions.more == oldRmore) {
1643+
if (limited || e->Runions.more == oldRmore) {
16281644
// re-save env and freeze the ∃decision for previous ∀Union
16291645
// Note: We could ignore the rest `∃Union` decisions if `x` and `y`
16301646
// contain no ∃ typevar, as they have no effect on env.
16311647
ini_count = count;
16321648
push_unionstate(&latestLunions, &e->Lunions);
1633-
re_save_env(e, &se, has_exists);
1649+
re_save_env(e, &se, 1);
16341650
e->Runions.more = oldRmore;
16351651
}
16361652
}
16371653
if (sub || e->Runions.more == oldRmore)
16381654
break;
16391655
assert(e->Runions.more > oldRmore);
16401656
next_union_state(e, 1);
1641-
restore_env(e, &se, has_exists); // also restore Rdepth here
1657+
restore_env(e, &se, 1); // also restore Rdepth here
16421658
e->Runions.more = oldRmore;
16431659
}
16441660
if (!sub)
16451661
assert(e->Runions.more == oldRmore);
1646-
else if (limited || !has_exists)
1662+
else if (limited)
16471663
e->Runions.more = oldRmore;
16481664
free_env(&se);
16491665
return sub;

test/subtype.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,3 +2768,17 @@ end
27682768
Tuple{Type{Complex{T}} where T, Type{Complex{T}} where T, Type{String}},
27692769
Tuple{Type{Complex{T}}, Type{Complex{T}}, Type{String}} where T
27702770
)
2771+
2772+
#issue 58129
2773+
for k in 1:500
2774+
@eval struct $(Symbol(:T58129, k)){T} end
2775+
end
2776+
let Tvar = TypeVar(:Tvar)
2777+
V = UnionAll(Tvar, Union{(@eval($(Symbol(:T58129, k)){$Tvar}) for k in 1:500)...})
2778+
@test Set{<:V} <: AbstractSet{<:V}
2779+
end
2780+
let Tvar1 = TypeVar(:Tvar1), Tvar2 = TypeVar(:Tvar2)
2781+
V1 = UnionAll(Tvar1, Union{(@eval($(Symbol(:T58129, k)){$Tvar1}) for k in 1:100)...})
2782+
V2 = UnionAll(Tvar2, Union{(@eval($(Symbol(:T58129, k)){$Tvar2}) for k in 1:100)...})
2783+
@test Set{<:V2} <: AbstractSet{<:V1}
2784+
end

0 commit comments

Comments
 (0)