Skip to content

Commit 3b088df

Browse files
committed
typeintersect: fix bounds merging during inner intersect_all.
This is a backport of #55299
1 parent 396b557 commit 3b088df

File tree

2 files changed

+88
-134
lines changed

2 files changed

+88
-134
lines changed

src/subtype.c

Lines changed: 52 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ typedef struct jl_varbinding_t {
6565
jl_value_t *lb;
6666
jl_value_t *ub;
6767
int8_t right; // whether this variable came from the right side of `A <: B`
68-
int8_t occurs; // occurs in any position
6968
int8_t occurs_inv; // occurs in invariant position
7069
int8_t occurs_cov; // # of occurrences in covariant position
7170
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
@@ -168,7 +167,7 @@ static int current_env_length(jl_stenv_t *e)
168167
typedef struct {
169168
int8_t *buf;
170169
int rdepth;
171-
int8_t _space[24]; // == 8 * 3
170+
int8_t _space[16]; // == 8 * 2
172171
jl_gcframe_t gcframe;
173172
jl_value_t *roots[24];
174173
} jl_savedenv_t;
@@ -197,7 +196,6 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
197196
roots[i++] = v->ub;
198197
roots[i++] = (jl_value_t*)v->innervars;
199198
}
200-
se->buf[j++] = v->occurs;
201199
se->buf[j++] = v->occurs_inv;
202200
se->buf[j++] = v->occurs_cov;
203201
v = v->prev;
@@ -278,7 +276,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
278276
v->ub = roots[i++];
279277
v->innervars = (jl_array_t*)roots[i++];
280278
}
281-
v->occurs = se->buf[j++];
282279
v->occurs_inv = se->buf[j++];
283280
v->occurs_cov = se->buf[j++];
284281
v = v->prev;
@@ -289,15 +286,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
289286
memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*));
290287
}
291288

292-
static void clean_occurs(jl_stenv_t *e)
293-
{
294-
jl_varbinding_t *v = e->vars;
295-
while (v) {
296-
v->occurs = 0;
297-
v = v->prev;
298-
}
299-
}
300-
301289
#define flip_offset(e) ((e)->Loffset *= -1)
302290

303291
// type utilities
@@ -586,6 +574,8 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
586574

587575
static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
588576

577+
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)
578+
589579
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
590580
{
591581
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
@@ -666,8 +656,6 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
666656
// of determining whether the variable is concrete.
667657
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
668658
{
669-
if (vb != NULL)
670-
vb->occurs = 1;
671659
if (vb != NULL && param) {
672660
// saturate counters at 2; we don't need values bigger than that
673661
if (param == 2 && e->invdepth > vb->depth0) {
@@ -898,7 +886,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
898886
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
899887
{
900888
u = unalias_unionall(u, e);
901-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
889+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
902890
e->invdepth, NULL, e->vars };
903891
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
904892
e->vars = &vb;
@@ -3198,7 +3186,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
31983186
{
31993187
jl_value_t *res = NULL;
32003188
jl_savedenv_t se;
3201-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
3189+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
32023190
e->invdepth, NULL, e->vars };
32033191
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
32043192
save_env(e, &se, 1);
@@ -3226,7 +3214,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
32263214
vb.ub = vb.var->ub;
32273215
}
32283216
restore_env(e, &se, vb.constraintkind == 1 ? 1 : 0);
3229-
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
3217+
vb.occurs_cov = vb.occurs_inv = 0;
32303218
res = intersect_unionall_(t, u, e, R, param, &vb);
32313219
}
32323220
}
@@ -3893,73 +3881,12 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
38933881
return jl_bottom_type;
38943882
}
38953883

3896-
static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
3884+
static int merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se, int count)
38973885
{
3898-
if (count == 0)
3899-
alloc_env(e, se, 1);
3900-
jl_value_t **roots = NULL;
3901-
int nroots = 0;
3902-
if (se->gcframe.nroots == JL_GC_ENCODE_PUSHARGS(1)) {
3903-
jl_svec_t *sv = (jl_svec_t*)se->roots[0];
3904-
assert(jl_is_svec(sv));
3905-
roots = jl_svec_data(sv);
3906-
nroots = jl_svec_len(sv);
3907-
}
3908-
else {
3909-
roots = se->roots;
3910-
nroots = se->gcframe.nroots >> 2;
3911-
}
3912-
int n = 0;
3913-
jl_varbinding_t *v = e->vars;
3914-
v = e->vars;
3915-
while (v != NULL) {
3916-
if (count == 0) {
3917-
// need to initialize this
3918-
se->buf[n] = 0;
3919-
se->buf[n+1] = 0;
3920-
se->buf[n+2] = 0;
3921-
}
3922-
if (v->occurs) {
3923-
// only merge lb/ub/innervars if this var occurs.
3924-
jl_value_t *b1, *b2;
3925-
b1 = roots[n];
3926-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3927-
b2 = v->lb;
3928-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
3929-
roots[n] = b1 ? simple_meet(b1, b2, 0) : b2;
3930-
b1 = roots[n+1];
3931-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3932-
b2 = v->ub;
3933-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
3934-
roots[n+1] = b1 ? simple_join(b1, b2) : b2;
3935-
b1 = roots[n+2];
3936-
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3937-
b2 = (jl_value_t*)v->innervars;
3938-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
3939-
if (b2 && b1 != b2) {
3940-
if (b1)
3941-
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
3942-
else
3943-
roots[n+2] = b2;
3944-
}
3945-
// record the meeted vars.
3946-
se->buf[n] = 1;
3947-
}
3948-
// always merge occurs_inv/cov by max (never decrease)
3949-
if (v->occurs_inv > se->buf[n+1])
3950-
se->buf[n+1] = v->occurs_inv;
3951-
if (v->occurs_cov > se->buf[n+2])
3952-
se->buf[n+2] = v->occurs_cov;
3953-
n = n + 3;
3954-
v = v->prev;
3886+
if (count == 0) {
3887+
save_env(e, me, 1);
3888+
return 1;
39553889
}
3956-
assert(n == nroots); (void)nroots;
3957-
return count + 1;
3958-
}
3959-
3960-
// merge untouched vars' info.
3961-
static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
3962-
{
39633890
jl_value_t **merged = NULL;
39643891
jl_value_t **saved = NULL;
39653892
int nroots = 0;
@@ -3981,47 +3908,46 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
39813908
}
39823909
assert(nroots == current_env_length(e) * 3);
39833910
assert(nroots % 3 == 0);
3984-
for (int n = 0; n < nroots; n = n + 3) {
3985-
if (merged[n] == NULL)
3986-
merged[n] = saved[n];
3987-
if (merged[n+1] == NULL)
3988-
merged[n+1] = saved[n+1];
3989-
jl_value_t *b1, *b2;
3911+
int m = 0, n = 0;
3912+
jl_varbinding_t *v = e->vars;
3913+
while (v != NULL) {
3914+
jl_value_t *b0, *b1, *b2;
3915+
// merge `lb`
3916+
b0 = saved[n];
3917+
b1 = merged[n];
3918+
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3919+
b2 = v->lb;
3920+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
3921+
merged[n] = (b1 == b0 || b2 == b0) ? b0 : simple_meet(b1, b2, 0);
3922+
// merge `ub`
3923+
b0 = saved[n+1];
3924+
b1 = merged[n+1];
3925+
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3926+
b2 = v->ub;
3927+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
3928+
merged[n+1] = (b1 == b0 || b2 == b0) ? b0 : simple_join(b1, b2);
3929+
// merge `innervars`
39903930
b1 = merged[n+2];
39913931
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
3992-
b2 = saved[n+2];
3993-
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know this came from our GC frame
3932+
b2 = (jl_value_t*)v->innervars;
3933+
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
39943934
if (b2 && b1 != b2) {
39953935
if (b1)
39963936
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
39973937
else
39983938
merged[n+2] = b2;
39993939
}
4000-
me->buf[n] |= se->buf[n];
4001-
}
4002-
}
4003-
4004-
static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
4005-
{
4006-
jl_varbinding_t *v = e->vars;
4007-
// Here we pull in some typevar missed in fastpath.
4008-
while (v != NULL) {
4009-
v->occurs = v->occurs || jl_has_typevar(res, v->var);
4010-
assert(v->occurs == 0 || v->occurs == 1);
4011-
v = v->prev;
4012-
}
4013-
v = e->vars;
4014-
while (v != NULL) {
4015-
if (v->occurs == 1) {
4016-
jl_varbinding_t *v2 = e->vars;
4017-
while (v2 != NULL) {
4018-
if (v2 != v && v2->occurs == 0)
4019-
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
4020-
v2 = v2->prev;
4021-
}
4022-
}
3940+
// merge occurs_inv/cov by max (never decrease)
3941+
if (v->occurs_inv > me->buf[m])
3942+
me->buf[m] = v->occurs_inv;
3943+
if (v->occurs_cov > me->buf[m+1])
3944+
me->buf[m+1] = v->occurs_cov;
3945+
m = m + 3;
3946+
n = n + 3;
40233947
v = v->prev;
40243948
}
3949+
assert(n == nroots); (void)nroots;
3950+
return count + 1;
40253951
}
40263952

40273953
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
@@ -4034,25 +3960,19 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
40343960
jl_savedenv_t se, me;
40353961
save_env(e, &se, 1);
40363962
int niter = 0, total_iter = 0;
4037-
clean_occurs(e);
40383963
is[0] = intersect(x, y, e, 0); // root
4039-
if (is[0] != jl_bottom_type) {
4040-
expand_local_env(e, is[0]);
4041-
niter = merge_env(e, &me, niter);
4042-
}
3964+
if (is[0] != jl_bottom_type)
3965+
niter = merge_env(e, &me, &se, niter);
40433966
restore_env(e, &se, 1);
40443967
while (next_union_state(e, 1)) {
40453968
if (e->emptiness_only && is[0] != jl_bottom_type)
40463969
break;
40473970
e->Runions.depth = 0;
40483971
e->Runions.more = 0;
40493972

4050-
clean_occurs(e);
40513973
is[1] = intersect(x, y, e, 0);
4052-
if (is[1] != jl_bottom_type) {
4053-
expand_local_env(e, is[1]);
4054-
niter = merge_env(e, &me, niter);
4055-
}
3974+
if (is[1] != jl_bottom_type)
3975+
niter = merge_env(e, &me, &se, niter);
40563976
restore_env(e, &se, 1);
40573977
if (is[0] == jl_bottom_type)
40583978
is[0] = is[1];
@@ -4061,13 +3981,18 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
40613981
is[0] = jl_type_union(is, 2);
40623982
}
40633983
total_iter++;
4064-
if (niter > 4 || total_iter > 400000) {
3984+
if (has_next_union_state(e, 1) && (niter > 4 || total_iter > 400000)) {
40653985
is[0] = y;
3986+
// we give up precise intersection here, just restore the saved env
3987+
restore_env(e, &se, 1);
3988+
if (niter > 0) {
3989+
free_env(&me);
3990+
niter = 0;
3991+
}
40663992
break;
40673993
}
40683994
}
40693995
if (niter) {
4070-
final_merge_env(e, &me, &se);
40713996
restore_env(e, &me, 1);
40723997
free_env(&me);
40733998
}
@@ -4552,7 +4477,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
45524477

45534478
static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
45544479
{
4555-
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
4480+
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
45564481
jl_value_t *nt;
45574482
JL_GC_PUSH2(&vb.innervars, &nt);
45584483
if (jl_is_unionall(u->body))

test/subtype.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,12 +2374,41 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
23742374
@testintersect(S, T, !Union{})
23752375
end
23762376

2377-
# A simple case which has a small local union.
2378-
# make sure the env is not widened too much when we intersect(Int8, Int8).
2379-
struct T48006{A1,A2,A3} end
2380-
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
2381-
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
2382-
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})
2377+
let S = Dict{Int, S1} where {F1, S1<:Union{Int8, Val{F1}}},
2378+
T = Dict{F2, S2} where {F2, S2<:Union{Int8, Val{F2}}}
2379+
@test_broken typeintersect(S, T) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
2380+
@test typeintersect(T, S) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
2381+
end
2382+
2383+
# Ensure inner `intersect_all` never under-esitimate.
2384+
let S = Tuple{F1, Dict{Int, S1}} where {F1, S1<:Union{Int8, Val{F1}}},
2385+
T = Tuple{Any, Dict{F2, S2}} where {F2, S2<:Union{Int8, Val{F2}}}
2386+
@test Tuple{Nothing, Dict{Int, Int8}} <: S
2387+
@test Tuple{Nothing, Dict{Int, Int8}} <: T
2388+
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(S, T)
2389+
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(T, S)
2390+
end
2391+
2392+
let S = Tuple{F1, Val{S1}} where {F1, S1<:Dict{F1}}
2393+
T = Tuple{Any, Val{S2}} where {F2, S2<:Union{map(T->Dict{T}, Base.BitInteger_types)...}}
2394+
ST = typeintersect(S, T)
2395+
TS = typeintersect(S, T)
2396+
for U in Base.BitInteger_types
2397+
@test Tuple{U, Val{Dict{U,Nothing}}} <: S
2398+
@test Tuple{U, Val{Dict{U,Nothing}}} <: T
2399+
@test Tuple{U, Val{Dict{U,Nothing}}} <: ST
2400+
@test Tuple{U, Val{Dict{U,Nothing}}} <: TS
2401+
end
2402+
end
2403+
2404+
#issue 55206
2405+
struct T55206{A,B<:Complex{A},C<:Union{Dict{Nothing},Dict{A}}} end
2406+
@testintersect(T55206, T55206{<:Any,<:Any,<:Dict{Nothing}}, T55206{A,<:Complex{A},<:Dict{Nothing}} where {A})
2407+
@testintersect(
2408+
Tuple{Dict{Int8, Int16}, Val{S1}} where {F1, S1<:AbstractSet{F1}},
2409+
Tuple{Dict{T1, T2}, Val{S2}} where {T1, T2, S2<:Union{Set{T1},Set{T2}}},
2410+
Tuple{Dict{Int8, Int16}, Val{S1}} where {S1<:Union{Set{Int8},Set{Int16}}}
2411+
)
23832412

23842413
f48167(::Type{Val{L2}}, ::Type{Union{Val{L1}, Set{R}}}) where {L1, R, L2<:L1} = 1
23852414
f48167(::Type{Val{L1}}, ::Type{Union{Val{L2}, Set{R}}}) where {L1, R, L2<:L1} = 2
@@ -2548,7 +2577,7 @@ end
25482577
let T = Tuple{Union{Type{T}, Type{S}}, Union{Val{T}, Val{S}}, Union{Val{T}, S}} where T<:Val{A} where A where S<:Val,
25492578
S = Tuple{Type{T}, T, Val{T}} where T<:(Val{S} where S<:Val)
25502579
# optimal = Union{}?
2551-
@test typeintersect(T, S) == Tuple{Type{A}, Union{Val{A}, Val{S} where S<:Union{Val, A}, Val{x} where x<:Val, Val{x} where x<:Union{Val, A}}, Val{A}} where A<:(Val{S} where S<:Val)
2580+
@test typeintersect(T, S) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {S<:Val, T<:Val}
25522581
@test typeintersect(S, T) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {T<:Val, S<:(Union{Val{A}, Val} where A)}
25532582
end
25542583

0 commit comments

Comments
 (0)