Skip to content

Commit 58399e2

Browse files
serenity4aviatesk
andauthored
Extend PartialStruct to represent non-contiguously defined fields (#57304)
So far, `PartialStruct` has been unable to represent non-contiguously defined fields, where e.g. a struct would have fields 1 and 3 defined but not field 2. This PR extends it so that such information may be represented with `PartialStruct`, extending the applicability of optimizations e.g. introduced in #55297 by @aviatesk or #57222. The semantics of `new` prevent the creation of a struct with non-contiguously defined fields, therefore this change is mostly relevant to model mutable structs whose fields may be previously set or assumed to be defined after creation, or immutable structs whose creation is opaque. Notably, with this change we may now infer information about structs in the following case: ```julia mutable struct A; x; y; z; A() = new(); end function f() mut = A() # some opaque call preventing optimizations # who knows, maybe `identity` will set fields from `mut` in a future world age! invokelatest(identity, mut) isdefined(mut, :z) && isdefined(mut, :x) || return isdefined(mut, :x) & isdefined(mut, :z) # this now infers as `true` isdefined(mut, :y) # this does not end ``` whereas previously, only information gained successively with `isdefined(mut, :x) && isdefined(mut, :y) && isdefined(mut, :z)` could allow inference to model `mut` having its `z` field defined. --------- Co-authored-by: Cédric Belmant <[email protected]> Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent f7b986d commit 58399e2

File tree

13 files changed

+329
-93
lines changed

13 files changed

+329
-93
lines changed

Compiler/src/Compiler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
6767
partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method,
6868
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
6969
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
70-
_uncompressed_ir, maybe_add_binding_backedge!
70+
_uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized,
71+
partialstruct_undef_length, partialstruct_init_undef
7172
using Base.Order
7273

7374
import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,

Compiler/src/abstractinterpretation.jl

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,23 +2148,13 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
21482148
isabstracttype(objt) && return nothing
21492149
fldidx = try_compute_fieldidx(objt, name.val)
21502150
fldidx === nothing && return nothing
2151+
isa(obj, PartialStruct) && return define_field(obj, fldidx)
21512152
nminfld = datatype_min_ninitialized(objt)
2152-
if ismutabletype(objt)
2153-
# A mutable struct can have non-contiguous undefined fields, but `PartialStruct` cannot
2154-
# model such a state. So here `PartialStruct` can be used to represent only the
2155-
# objects where the field following the minimum initialized fields is also defined.
2156-
if fldidx nminfld+1
2157-
# if it is already represented as a `PartialStruct`, we can add one more
2158-
# `isdefined`-field information on top of those implied by its `fields`
2159-
if !(obj isa PartialStruct && fldidx == length(obj.fields)+1)
2160-
return nothing
2161-
end
2162-
end
2163-
else
2164-
fldidx > nminfld || return nothing
2165-
end
2166-
return PartialStruct(fallback_lattice, objt0, Any[obj isa PartialStruct && ilength(obj.fields) ?
2167-
obj.fields[i] : fieldtype(objt0,i) for i = 1:fldidx])
2153+
fldidx > nminfld || return nothing
2154+
undef = partialstruct_init_undef(objt, fldidx; all_defined = false)
2155+
undef[fldidx] = false
2156+
fields = Any[fieldtype(objt0, i) for i = 1:fldidx]
2157+
return PartialStruct(fallback_lattice, objt0, undef, fields)
21682158
end
21692159

21702160
function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
@@ -3725,8 +3715,7 @@ end
37253715
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
37263716
if isa(rt, PartialStruct)
37273717
fields = copy(rt.fields)
3728-
anyrefine = !isvarargtype(rt.fields[end]) &&
3729-
length(rt.fields) > datatype_min_ninitialized(rt.typ)
3718+
anyrefine = refines_definedness_information(rt)
37303719
𝕃 = typeinf_lattice(info.interp)
37313720
= strictpartialorder(𝕃)
37323721
for i in 1:length(fields)
@@ -3738,7 +3727,7 @@ end
37383727
end
37393728
fields[i] = a
37403729
end
3741-
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, fields)
3730+
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, rt.undef, fields)
37423731
end
37433732
if isa(rt, PartialOpaque)
37443733
return rt # XXX: this case was missed in #39512

Compiler/src/tfuncs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ end
439439
end
440440
elseif isa(arg1, PartialStruct)
441441
if !isvarargtype(arg1.fields[end])
442-
if 1 idx length(arg1.fields)
442+
if !is_field_maybe_undef(arg1, idx)
443443
return Const(true)
444444
end
445445
end
@@ -1141,8 +1141,8 @@ end
11411141
sty = unwrap_unionall(s)::DataType
11421142
if isa(name, Const)
11431143
nv = _getfield_fieldindex(sty, name)
1144-
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
1145-
return unwrapva(s00.fields[nv])
1144+
if isa(nv, Int) && !is_field_maybe_undef(s00, nv)
1145+
return unwrapva(partialstruct_getfield(s00, nv))
11461146
end
11471147
end
11481148
s00 = s

Compiler/src/typeinfer.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
515515
rettype_const = result_type.parameters[1]
516516
const_flags = 0x2
517517
elseif isa(result_type, PartialStruct)
518-
rettype_const = result_type.fields
518+
rettype_const = (result_type.undef, result_type.fields)
519519
const_flags = 0x2
520520
elseif isa(result_type, InterConditional)
521521
rettype_const = result_type
@@ -959,8 +959,9 @@ function cached_return_type(code::CodeInstance)
959959
rettype_const = code.rettype_const
960960
# the second subtyping/egal conditions are necessary to distinguish usual cases
961961
# from rare cases when `Const` wrapped those extended lattice type objects
962-
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
963-
return PartialStruct(fallback_lattice, rettype, rettype_const)
962+
if isa(rettype_const, Tuple{BitVector, Vector{Any}}) && !(Tuple{BitVector, Vector{Any}} <: rettype)
963+
undef, fields = rettype_const
964+
return PartialStruct(fallback_lattice, rettype, undef, fields)
964965
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
965966
return rettype_const
966967
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional

Compiler/src/typelattice.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,15 @@ end
318318
fields = vartyp.fields
319319
thenfields = thentype === Bottom ? nothing : copy(fields)
320320
elsefields = elsetype === Bottom ? nothing : copy(fields)
321-
for i in 1:length(fields)
322-
if i == fldidx
323-
thenfields === nothing || (thenfields[i] = thentype)
324-
elsefields === nothing || (elsefields[i] = elsetype)
325-
end
321+
undef = copy(vartyp.undef)
322+
if 1 fldidx length(fields)
323+
thenfields === nothing || (thenfields[fldidx] = thentype)
324+
elsefields === nothing || (elsefields[fldidx] = elsetype)
325+
undef[fldidx] = false
326326
end
327327
return Conditional(slot,
328-
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, thenfields),
329-
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, elsefields))
328+
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, thenfields),
329+
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, elsefields))
330330
else
331331
vartyp_widened = widenconst(vartyp)
332332
thenfields = thentype === Bottom ? nothing : Any[]
@@ -431,10 +431,14 @@ end
431431
return false
432432
end
433433
end
434-
for i in 1:length(b.fields)
435-
af = a.fields[i]
436-
bf = b.fields[i]
437-
if i == length(b.fields)
434+
na = length(a.fields)
435+
nb = length(b.fields)
436+
nmax = max(na, nb)
437+
for i in 1:nmax
438+
is_field_maybe_undef(a, i) is_field_maybe_undef(b, i) || return false
439+
af = partialstruct_getfield(a, i)
440+
bf = partialstruct_getfield(b, i)
441+
if i == na || i == nb
438442
if isvarargtype(af)
439443
# If `af` is vararg, so must bf by the <: above
440444
@assert isvarargtype(bf)
@@ -464,12 +468,15 @@ end
464468
nfields(a.val) == length(b.fields) || return false
465469
else
466470
widea <: wideb || return false
467-
# for structs we need to check that `a` has more information than `b` that may be partially initialized
468-
n_initialized(a) length(b.fields) || return false
471+
# for structs we need to check that `a` does not have less information than `b` that may be partially initialized
472+
n_initialized(a) n_initialized(b) || return false
469473
end
470474
nf = nfields(a.val)
471475
for i in 1:nf
472-
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
476+
if !isdefined(a.val, i)
477+
is_field_maybe_undef(b, i) || return false # conflicting defined-ness information
478+
continue # since ∀ T Union{} ⊑ T
479+
end
473480
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
474481
bfᵢ = b.fields[i]
475482
if i == nf
@@ -541,6 +548,7 @@ end
541548
if isa(a, PartialStruct)
542549
isa(b, PartialStruct) || return false
543550
length(a.fields) == length(b.fields) || return false
551+
a.undef == b.undef || return false
544552
widenconst(a) == widenconst(b) || return false
545553
a.fields === b.fields && return true # fast path
546554
for i in 1:length(a.fields)
@@ -747,9 +755,15 @@ end
747755
# The ::AbstractLattice argument is unused and simply serves to disambiguate
748756
# different instances of the compiler that may share the `Core.PartialStruct`
749757
# type.
750-
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), fields::Vector{Any})
758+
759+
function Core.PartialStruct(𝕃::AbstractLattice, @nospecialize(typ), fields::Vector{Any}; all_defined::Bool = true)
760+
undef = partialstruct_init_undef(typ, fields; all_defined)
761+
return PartialStruct(𝕃, typ, undef, fields)
762+
end
763+
764+
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVector, fields::Vector{Any})
751765
for i = 1:length(fields)
752766
assert_nested_slotwrapper(fields[i])
753767
end
754-
return Core._PartialStruct(typ, fields)
768+
return PartialStruct(typ, undef, fields)
755769
end

Compiler/src/typelimits.jl

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,17 +326,74 @@ function n_initialized(t::Const)
326326
return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1
327327
end
328328

329+
is_field_maybe_undef(t::Const, i) = !isdefined(t.val, i)
330+
331+
function n_initialized(pstruct::PartialStruct)
332+
i = findfirst(pstruct.undef)
333+
nmin = datatype_min_ninitialized(pstruct.typ)
334+
i === nothing && return max(length(pstruct.undef), nmin)
335+
n = i::Int - 1
336+
@assert n nmin
337+
n
338+
end
339+
340+
function is_field_maybe_undef(pstruct::PartialStruct, fi)
341+
fi 1 || return true
342+
fi length(pstruct.undef) && return pstruct.undef[fi]
343+
fi > datatype_min_ninitialized(pstruct.typ)
344+
end
345+
346+
function partialstruct_getfield(pstruct::PartialStruct, fi::Integer)
347+
@assert fi > 0
348+
fi length(pstruct.fields) && return pstruct.fields[fi]
349+
fieldtype(pstruct.typ, fi)
350+
end
351+
352+
function refines_definedness_information(pstruct::PartialStruct)
353+
nflds = length(pstruct.undef)
354+
something(findfirst(pstruct.undef), nflds + 1) - 1 > datatype_min_ninitialized(pstruct.typ)
355+
end
356+
357+
function define_field(pstruct::PartialStruct, fi::Int)
358+
if !is_field_maybe_undef(pstruct, fi)
359+
# no new information to be gained
360+
return nothing
361+
end
362+
363+
new = expand_partialstruct(pstruct, fi)
364+
if new === nothing
365+
new = PartialStruct(fallback_lattice, pstruct.typ, copy(pstruct.undef), copy(pstruct.fields))
366+
end
367+
new.undef[fi] = false
368+
return new
369+
end
370+
371+
function expand_partialstruct(pstruct::PartialStruct, until::Int)
372+
n = length(pstruct.undef)
373+
until n && return nothing
374+
375+
undef = partialstruct_init_undef(pstruct.typ, until; all_defined = false)
376+
for i in 1:n
377+
undef[i] &= pstruct.undef[i]
378+
end
379+
nf = length(pstruct.fields)
380+
typ = pstruct.typ
381+
fields = Any[i nf ? pstruct.fields[i] : fieldtype(typ, i) for i in 1:until]
382+
return PartialStruct(fallback_lattice, typ, undef, fields)
383+
end
384+
329385
# A simplified type_more_complex query over the extended lattice
330386
# (assumes typeb ⊑ typea)
331387
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
332388
@assert !isa(typea, LimitedAccuracy) && !isa(typeb, LimitedAccuracy) "LimitedAccuracy not supported by simplertype lattice" # n.b. the caller was supposed to handle these
333389
typea === typeb && return true
334390
if typea isa PartialStruct
335391
aty = widenconst(typea)
336-
if typeb isa Const
337-
@assert length(typea.fields) n_initialized(typeb) "typeb ⊑ typea is assumed"
392+
if typeb isa Const || typeb isa PartialStruct
393+
@assert n_initialized(typea) n_initialized(typeb) "typeb ⊑ typea is assumed"
338394
elseif typeb isa PartialStruct
339-
@assert length(typea.fields) length(typeb.fields) "typeb ⊑ typea is assumed"
395+
@assert n_initialized(typea) n_initialized(typeb) &&
396+
all(b < a for (a, b) in zip(typea.undef, typeb.undef)) "typeb ⊑ typea is assumed"
340397
else
341398
return false
342399
end
@@ -591,17 +648,24 @@ end
591648
if typea isa PartialStruct
592649
if typeb isa PartialStruct
593650
nflds = min(length(typea.fields), length(typeb.fields))
651+
nundef = nflds - (isvarargtype(typea.fields[end]) && isvarargtype(typeb.fields[end]))
594652
else
595653
nflds = min(length(typea.fields), n_initialized(typeb::Const))
654+
nundef = nflds
596655
end
597656
elseif typeb isa PartialStruct
598657
nflds = min(n_initialized(typea::Const), length(typeb.fields))
658+
nundef = nflds
599659
else
600660
nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const))
661+
nundef = nflds
601662
end
602663
nflds == 0 && return nothing
664+
_undef = partialstruct_init_undef(aty, nundef; all_defined = false)
603665
fields = Vector{Any}(undef, nflds)
604-
anyrefine = nflds > datatype_min_ninitialized(aty)
666+
fldmin = datatype_min_ninitialized(aty)
667+
n_initialized_merged = min(n_initialized(typea::Union{Const, PartialStruct}), n_initialized(typeb::Union{Const, PartialStruct}))
668+
anyrefine = n_initialized_merged > fldmin
605669
for i = 1:nflds
606670
ai = getfield_tfunc(𝕃, typea, Const(i))
607671
bi = getfield_tfunc(𝕃, typeb, Const(i))
@@ -633,12 +697,16 @@ end
633697
end
634698
end
635699
fields[i] = tyi
700+
if i nundef
701+
_undef[i] = is_field_maybe_undef(typea, i) || is_field_maybe_undef(typeb, i)
702+
end
636703
if !anyrefine
637704
anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information
638-
(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type
705+
(𝕃, tyi, ft) || # just a type-level information, but more precise than the declared type
706+
!get(_undef, i, true) && i > fldmin # possibly uninitialized field is known to be initialized
639707
end
640708
end
641-
anyrefine && return PartialStruct(𝕃, aty, fields)
709+
anyrefine && return PartialStruct(𝕃, aty, _undef, fields)
642710
end
643711
return nothing
644712
end

Compiler/src/typeutils.jl

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -61,39 +61,6 @@ function isknownlength(t::DataType)
6161
return isdefined(va, :N) && va.N isa Int
6262
end
6363

64-
# Compute the minimum number of initialized fields for a particular datatype
65-
# (therefore also a lower bound on the number of fields)
66-
function datatype_min_ninitialized(@nospecialize t0)
67-
t = unwrap_unionall(t0)
68-
t isa DataType || return 0
69-
isabstracttype(t) && return 0
70-
if t.name === _NAMEDTUPLE_NAME
71-
names, types = t.parameters[1], t.parameters[2]
72-
if names isa Tuple
73-
return length(names)
74-
end
75-
t = argument_datatype(types)
76-
t isa DataType || return 0
77-
t.name === Tuple.name || return 0
78-
end
79-
if t.name === Tuple.name
80-
n = length(t.parameters)
81-
n == 0 && return 0
82-
va = t.parameters[n]
83-
if isvarargtype(va)
84-
n -= 1
85-
if isdefined(va, :N)
86-
va = va.N
87-
if va isa Int
88-
n += va
89-
end
90-
end
91-
end
92-
return n
93-
end
94-
return length(t.name.names) - t.name.n_uninitialized
95-
end
96-
9764
has_concrete_subtype(d::DataType) = d.flags & 0x0020 == 0x0020 # n.b. often computed only after setting the type and layout fields
9865

9966
# determine whether x is a valid lattice element

0 commit comments

Comments
 (0)