Skip to content

Commit 4381c05

Browse files
N5N3KristofferC
authored andcommitted
typeintersect: fix bounds merging during inner intersect_all (#55299)
This PR reverts the optimization from 748149e (part of #48167), while keeping the fix for merging occurs_inv/occurs_cov, as that optimzation makes no sense especially when typevar occurs both inside and outside the inner intersection. Close #55206 (cherry picked from commit fb6b790)
1 parent e0b2828 commit 4381c05

File tree

2 files changed

+92
-141
lines changed

2 files changed

+92
-141
lines changed

src/subtype.c

Lines changed: 56 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ typedef struct jl_varbinding_t {
6565
jl_value_t *lb;
6666
jl_value_t *ub;
6767
int8_t right; // whether this variable came from the right side of `A <: B`
68-
int8_t occurs; // occurs in any position
6968
int8_t occurs_inv; // occurs in invariant position
7069
int8_t occurs_cov; // # of occurrences in covariant position
7170
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
@@ -179,7 +178,7 @@ static int current_env_length(jl_stenv_t *e)
179178
typedef struct {
180179
int8_t *buf;
181180
int rdepth;
182-
int8_t _space[32]; // == 8 * 4
181+
int8_t _space[24]; // == 8 * 3
183182
jl_gcframe_t gcframe;
184183
jl_value_t *roots[24]; // == 8 * 3
185184
} jl_savedenv_t;
@@ -208,7 +207,6 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
208207
roots[i++] = v->ub;
209208
roots[i++] = (jl_value_t*)v->innervars;
210209
}
211-
se->buf[j++] = v->occurs;
212210
se->buf[j++] = v->occurs_inv;
213211
se->buf[j++] = v->occurs_cov;
214212
se->buf[j++] = v->max_offset;
@@ -243,7 +241,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
243241
ct->gcstack = &se->gcframe;
244242
}
245243
}
246-
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space);
244+
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space);
247245
#ifdef __clang_gcanalyzer__
248246
memset(se->buf, 0, len * 3);
249247
#endif
@@ -290,7 +288,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
290288
v->ub = roots[i++];
291289
v->innervars = (jl_array_t*)roots[i++];
292290
}
293-
v->occurs = se->buf[j++];
294291
v->occurs_inv = se->buf[j++];
295292
v->occurs_cov = se->buf[j++];
296293
v->max_offset = se->buf[j++];
@@ -302,15 +299,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
302299
memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*));
303300
}
304301

305-
static void clean_occurs(jl_stenv_t *e)
306-
{
307-
jl_varbinding_t *v = e->vars;
308-
while (v) {
309-
v->occurs = 0;
310-
v = v->prev;
311-
}
312-
}
313-
314302
#define flip_offset(e) ((e)->Loffset *= -1)
315303

316304
// type utilities
@@ -599,6 +587,8 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
599587

600588
static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
601589

590+
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)
591+
602592
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
603593
{
604594
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
@@ -679,8 +669,6 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
679669
// of determining whether the variable is concrete.
680670
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
681671
{
682-
if (vb != NULL)
683-
vb->occurs = 1;
684672
if (vb != NULL && param) {
685673
// saturate counters at 2; we don't need values bigger than that
686674
if (param == 2 && e->invdepth > vb->depth0) {
@@ -915,7 +903,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
915903
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
916904
{
917905
u = unalias_unionall(u, e);
918-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
906+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
919907
e->invdepth, NULL, e->vars };
920908
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
921909
e->vars = &vb;
@@ -3312,7 +3300,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
33123300
{
33133301
jl_value_t *res = NULL;
33143302
jl_savedenv_t se;
3315-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
3303+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
33163304
e->invdepth, NULL, e->vars };
33173305
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
33183306
save_env(e, &se, 1);
@@ -3341,7 +3329,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
33413329
vb.ub = vb.var->ub;
33423330
}
33433331
restore_env(e, &se, vb.constraintkind == 1 ? 1 : 0);
3344-
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
3332+
vb.occurs_cov = vb.occurs_inv = 0;
33453333
res = intersect_unionall_(t, u, e, R, param, &vb);
33463334
}
33473335
}
@@ -4042,79 +4030,12 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
40424030
return jl_bottom_type;
40434031
}
40444032

4045-
static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
4033+
static int merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se, int count)
40464034
{
4047-
if (count == 0)
4048-
alloc_env(e, se, 1);
4049-
jl_value_t **roots = NULL;
4050-
int nroots = 0;
4051-
if (se->gcframe.nroots == JL_GC_ENCODE_PUSHARGS(1)) {
4052-
jl_svec_t *sv = (jl_svec_t*)se->roots[0];
4053-
assert(jl_is_svec(sv));
4054-
roots = jl_svec_data(sv);
4055-
nroots = jl_svec_len(sv);
4056-
}
4057-
else {
4058-
roots = se->roots;
4059-
nroots = se->gcframe.nroots >> 2;
4060-
}
4061-
int m = 0, n = 0;
4062-
jl_varbinding_t *v = e->vars;
4063-
while (v != NULL) {
4064-
if (count == 0) {
4065-
// need to initialize this
4066-
se->buf[m] = 0;
4067-
se->buf[m+1] = 0;
4068-
se->buf[m+2] = 0;
4069-
se->buf[m+3] = v->max_offset;
4070-
}
4071-
jl_value_t *b1, *b2;
4072-
if (v->occurs) {
4073-
// only merge lb/ub if this var occurs.
4074-
b1 = roots[n];
4075-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4076-
b2 = v->lb;
4077-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
4078-
roots[n] = b1 ? simple_meet(b1, b2, 0) : b2;
4079-
b1 = roots[n+1];
4080-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4081-
b2 = v->ub;
4082-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
4083-
roots[n+1] = b1 ? simple_join(b1, b2) : b2;
4084-
// record the meeted vars.
4085-
se->buf[m] = 1;
4086-
}
4087-
// `innervars` might be re-sorted inside `finish_unionall`.
4088-
// We'd better always merge it.
4089-
b1 = roots[n+2];
4090-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4091-
b2 = (jl_value_t*)v->innervars;
4092-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
4093-
if (b2 && b1 != b2) {
4094-
if (b1)
4095-
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
4096-
else
4097-
roots[n+2] = b2;
4098-
}
4099-
// always merge occurs_inv/cov by max (never decrease)
4100-
if (v->occurs_inv > se->buf[m+1])
4101-
se->buf[m+1] = v->occurs_inv;
4102-
if (v->occurs_cov > se->buf[m+2])
4103-
se->buf[m+2] = v->occurs_cov;
4104-
// always merge max_offset by min
4105-
if (!v->intersected && v->max_offset < se->buf[m+3])
4106-
se->buf[m+3] = v->max_offset;
4107-
m = m + 4;
4108-
n = n + 3;
4109-
v = v->prev;
4035+
if (count == 0) {
4036+
save_env(e, me, 1);
4037+
return 1;
41104038
}
4111-
assert(n == nroots); (void)nroots;
4112-
return count + 1;
4113-
}
4114-
4115-
// merge untouched vars' info.
4116-
static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
4117-
{
41184039
jl_value_t **merged = NULL;
41194040
jl_value_t **saved = NULL;
41204041
int nroots = 0;
@@ -4136,47 +4057,49 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
41364057
}
41374058
assert(nroots == current_env_length(e) * 3);
41384059
assert(nroots % 3 == 0);
4139-
for (int n = 0, m = 0; n < nroots; n += 3, m += 4) {
4140-
if (merged[n] == NULL)
4141-
merged[n] = saved[n];
4142-
if (merged[n+1] == NULL)
4143-
merged[n+1] = saved[n+1];
4144-
jl_value_t *b1, *b2;
4060+
int m = 0, n = 0;
4061+
jl_varbinding_t *v = e->vars;
4062+
while (v != NULL) {
4063+
jl_value_t *b0, *b1, *b2;
4064+
// merge `lb`
4065+
b0 = saved[n];
4066+
b1 = merged[n];
4067+
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4068+
b2 = v->lb;
4069+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
4070+
merged[n] = (b1 == b0 || b2 == b0) ? b0 : simple_meet(b1, b2, 0);
4071+
// merge `ub`
4072+
b0 = saved[n+1];
4073+
b1 = merged[n+1];
4074+
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4075+
b2 = v->ub;
4076+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
4077+
merged[n+1] = (b1 == b0 || b2 == b0) ? b0 : simple_join(b1, b2);
4078+
// merge `innervars`
41454079
b1 = merged[n+2];
41464080
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
4147-
b2 = saved[n+2];
4148-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know this came from our GC frame
4081+
b2 = (jl_value_t*)v->innervars;
4082+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
41494083
if (b2 && b1 != b2) {
41504084
if (b1)
41514085
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
41524086
else
41534087
merged[n+2] = b2;
41544088
}
4155-
me->buf[m] |= se->buf[m];
4156-
}
4157-
}
4158-
4159-
static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
4160-
{
4161-
jl_varbinding_t *v = e->vars;
4162-
// Here we pull in some typevar missed in fastpath.
4163-
while (v != NULL) {
4164-
v->occurs = v->occurs || jl_has_typevar(res, v->var);
4165-
assert(v->occurs == 0 || v->occurs == 1);
4166-
v = v->prev;
4167-
}
4168-
v = e->vars;
4169-
while (v != NULL) {
4170-
if (v->occurs == 1) {
4171-
jl_varbinding_t *v2 = e->vars;
4172-
while (v2 != NULL) {
4173-
if (v2 != v && v2->occurs == 0)
4174-
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
4175-
v2 = v2->prev;
4176-
}
4177-
}
4089+
// merge occurs_inv/cov by max (never decrease)
4090+
if (v->occurs_inv > me->buf[m])
4091+
me->buf[m] = v->occurs_inv;
4092+
if (v->occurs_cov > me->buf[m+1])
4093+
me->buf[m+1] = v->occurs_cov;
4094+
// merge max_offset by min
4095+
if (!v->intersected && v->max_offset < me->buf[m+2])
4096+
me->buf[m+2] = v->max_offset;
4097+
m = m + 3;
4098+
n = n + 3;
41784099
v = v->prev;
41794100
}
4101+
assert(n == nroots); (void)nroots;
4102+
return count + 1;
41804103
}
41814104

41824105
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
@@ -4189,25 +4112,19 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
41894112
jl_savedenv_t se, me;
41904113
save_env(e, &se, 1);
41914114
int niter = 0, total_iter = 0;
4192-
clean_occurs(e);
41934115
is[0] = intersect(x, y, e, 0); // root
4194-
if (is[0] != jl_bottom_type) {
4195-
expand_local_env(e, is[0]);
4196-
niter = merge_env(e, &me, niter);
4197-
}
4116+
if (is[0] != jl_bottom_type)
4117+
niter = merge_env(e, &me, &se, niter);
41984118
restore_env(e, &se, 1);
41994119
while (next_union_state(e, 1)) {
42004120
if (e->emptiness_only && is[0] != jl_bottom_type)
42014121
break;
42024122
e->Runions.depth = 0;
42034123
e->Runions.more = 0;
42044124

4205-
clean_occurs(e);
42064125
is[1] = intersect(x, y, e, 0);
4207-
if (is[1] != jl_bottom_type) {
4208-
expand_local_env(e, is[1]);
4209-
niter = merge_env(e, &me, niter);
4210-
}
4126+
if (is[1] != jl_bottom_type)
4127+
niter = merge_env(e, &me, &se, niter);
42114128
restore_env(e, &se, 1);
42124129
if (is[0] == jl_bottom_type)
42134130
is[0] = is[1];
@@ -4216,13 +4133,18 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
42164133
is[0] = jl_type_union(is, 2);
42174134
}
42184135
total_iter++;
4219-
if (niter > 4 || total_iter > 400000) {
4136+
if (has_next_union_state(e, 1) && (niter > 4 || total_iter > 400000)) {
42204137
is[0] = y;
4138+
// we give up precise intersection here, just restore the saved env
4139+
restore_env(e, &se, 1);
4140+
if (niter > 0) {
4141+
free_env(&me);
4142+
niter = 0;
4143+
}
42214144
break;
42224145
}
42234146
}
42244147
if (niter) {
4225-
final_merge_env(e, &me, &se);
42264148
restore_env(e, &me, 1);
42274149
free_env(&me);
42284150
}
@@ -4707,7 +4629,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
47074629

47084630
static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
47094631
{
4710-
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
4632+
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
47114633
jl_value_t *nt;
47124634
JL_GC_PUSH2(&vb.innervars, &nt);
47134635
if (jl_is_unionall(u->body))

test/subtype.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,12 +2380,41 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
23802380
@testintersect(S, T, !Union{})
23812381
end
23822382

2383-
# A simple case which has a small local union.
2384-
# make sure the env is not widened too much when we intersect(Int8, Int8).
2385-
struct T48006{A1,A2,A3} end
2386-
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
2387-
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
2388-
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})
2383+
let S = Dict{Int, S1} where {F1, S1<:Union{Int8, Val{F1}}},
2384+
T = Dict{F2, S2} where {F2, S2<:Union{Int8, Val{F2}}}
2385+
@test_broken typeintersect(S, T) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
2386+
@test typeintersect(T, S) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
2387+
end
2388+
2389+
# Ensure inner `intersect_all` never under-esitimate.
2390+
let S = Tuple{F1, Dict{Int, S1}} where {F1, S1<:Union{Int8, Val{F1}}},
2391+
T = Tuple{Any, Dict{F2, S2}} where {F2, S2<:Union{Int8, Val{F2}}}
2392+
@test Tuple{Nothing, Dict{Int, Int8}} <: S
2393+
@test Tuple{Nothing, Dict{Int, Int8}} <: T
2394+
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(S, T)
2395+
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(T, S)
2396+
end
2397+
2398+
let S = Tuple{F1, Val{S1}} where {F1, S1<:Dict{F1}}
2399+
T = Tuple{Any, Val{S2}} where {F2, S2<:Union{map(T->Dict{T}, Base.BitInteger_types)...}}
2400+
ST = typeintersect(S, T)
2401+
TS = typeintersect(S, T)
2402+
for U in Base.BitInteger_types
2403+
@test Tuple{U, Val{Dict{U,Nothing}}} <: S
2404+
@test Tuple{U, Val{Dict{U,Nothing}}} <: T
2405+
@test Tuple{U, Val{Dict{U,Nothing}}} <: ST
2406+
@test Tuple{U, Val{Dict{U,Nothing}}} <: TS
2407+
end
2408+
end
2409+
2410+
#issue 55206
2411+
struct T55206{A,B<:Complex{A},C<:Union{Dict{Nothing},Dict{A}}} end
2412+
@testintersect(T55206, T55206{<:Any,<:Any,<:Dict{Nothing}}, T55206{A,<:Complex{A},<:Dict{Nothing}} where {A})
2413+
@testintersect(
2414+
Tuple{Dict{Int8, Int16}, Val{S1}} where {F1, S1<:AbstractSet{F1}},
2415+
Tuple{Dict{T1, T2}, Val{S2}} where {T1, T2, S2<:Union{Set{T1},Set{T2}}},
2416+
Tuple{Dict{Int8, Int16}, Val{S1}} where {S1<:Union{Set{Int8},Set{Int16}}}
2417+
)
23892418

23902419
f48167(::Type{Val{L2}}, ::Type{Union{Val{L1}, Set{R}}}) where {L1, R, L2<:L1} = 1
23912420
f48167(::Type{Val{L1}}, ::Type{Union{Val{L2}, Set{R}}}) where {L1, R, L2<:L1} = 2
@@ -2554,7 +2583,7 @@ end
25542583
let T = Tuple{Union{Type{T}, Type{S}}, Union{Val{T}, Val{S}}, Union{Val{T}, S}} where T<:Val{A} where A where S<:Val,
25552584
S = Tuple{Type{T}, T, Val{T}} where T<:(Val{S} where S<:Val)
25562585
# optimal = Union{}?
2557-
@test typeintersect(T, S) == Tuple{Type{A}, Union{Val{A}, Val{S} where S<:Union{Val, A}, Val{x} where x<:Val, Val{x} where x<:Union{Val, A}}, Val{A}} where A<:(Val{S} where S<:Val)
2586+
@test typeintersect(T, S) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {S<:Val, T<:Val}
25582587
@test typeintersect(S, T) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {T<:Val, S<:(Union{Val{A}, Val} where A)}
25592588
end
25602589

0 commit comments

Comments
 (0)