Skip to content

Commit 9b73611

Browse files
authored
SROA: don't use unswitchtupleunion and explicitly use type name only (#50522)
Since construction of `UnionAll` of `Union`s can be expensive. The SROA pass just needs to look at type name information and do not need to propagate full type objects. - xref: <#50511 (comment)> - closes #50511
1 parent 824cdf1 commit 9b73611

File tree

4 files changed

+20
-58
lines changed

4 files changed

+20
-58
lines changed

base/compiler/ssair/passes.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(f
6464
end
6565

6666
# assume `stmt` is a call of `getfield`/`setfield!`/`isdefined`
67-
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
67+
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, @nospecialize(typ))
6868
field = try_compute_field(ir, stmt.args[3])
6969
return try_compute_fieldidx(typ, field)
7070
end
@@ -1106,24 +1106,24 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
11061106
val = stmt.args[2]
11071107
end
11081108
struct_typ = widenconst(argextype(val, compact))
1109-
struct_typ_unwrapped = unwrap_unionall(struct_typ)
1110-
if isa(struct_typ, Union)
1111-
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
1112-
end
1113-
if isa(struct_typ_unwrapped, Union) && is_isdefined
1114-
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
1109+
struct_typ_name = argument_datatype(struct_typ)
1110+
if struct_typ_name === nothing
1111+
if isa(struct_typ, Union)
1112+
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
1113+
end
11151114
continue
1115+
else
1116+
struct_typ_name = struct_typ_name.name
11161117
end
1117-
isa(struct_typ_unwrapped, DataType) || continue
11181118

1119-
struct_typ_unwrapped.name.atomicfields == C_NULL || continue # TODO: handle more
1119+
struct_typ_name.atomicfields == C_NULL || continue # TODO: handle more
11201120
if !((field_ordering === :unspecified) ||
11211121
(field_ordering isa Const && field_ordering.val === :not_atomic))
11221122
continue
11231123
end
11241124

11251125
# analyze this mutable struct here for the later pass
1126-
if ismutabletype(struct_typ_unwrapped)
1126+
if ismutabletypename(struct_typ_name)
11271127
isa(val, SSAValue) || continue
11281128
let intermediaries = SPCSet()
11291129
callback = IntermediaryCollector(intermediaries)
@@ -1153,7 +1153,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
11531153
end
11541154

11551155
# perform SROA on immutable structs here on
1156-
field = try_compute_fieldidx_stmt(compact, stmt, struct_typ_unwrapped)
1156+
field = try_compute_fieldidx_stmt(compact, stmt, struct_typ)
11571157
field === nothing && continue
11581158

11591159
leaves, visited_philikes = collect_leaves(compact, val, struct_typ, 𝕃ₒ, phi_or_ifelse_predecessors)

base/compiler/tfuncs.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -877,13 +877,10 @@ function fieldcount_noerror(@nospecialize t)
877877
if t === nothing
878878
return nothing
879879
end
880-
t = t::DataType
881880
elseif t === Union{}
882881
return 0
883882
end
884-
if !(t isa DataType)
885-
return nothing
886-
end
883+
t isa DataType || return nothing
887884
if t.name === _NAMEDTUPLE_NAME
888885
names, types = t.parameters
889886
if names isa Tuple
@@ -892,17 +889,16 @@ function fieldcount_noerror(@nospecialize t)
892889
if types isa DataType && types <: Tuple
893890
return fieldcount_noerror(types)
894891
end
895-
abstr = true
896-
else
897-
abstr = isabstracttype(t) || (t.name === Tuple.name && isvatuple(t))
898-
end
899-
if abstr
892+
return nothing
893+
elseif isabstracttype(t) || (t.name === Tuple.name && isvatuple(t))
900894
return nothing
901895
end
902896
return isdefined(t, :types) ? length(t.types) : length(t.name.names)
903897
end
904898

905-
function try_compute_fieldidx(typ::DataType, @nospecialize(field))
899+
function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field))
900+
typ = argument_datatype(typ)
901+
typ === nothing && return nothing
906902
if isa(field, Symbol)
907903
field = fieldindex(typ, field, false)
908904
field == 0 && return nothing

base/compiler/typeutils.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -317,42 +317,6 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
317317
return depth
318318
end
319319

320-
# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
321-
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
322-
ts = uniontypes(u)
323-
n = -1
324-
for t in ts
325-
t isa DataType || return u
326-
if typename === nothing
327-
typename = t.name
328-
elseif typename !== t.name
329-
return u
330-
end
331-
params = t.parameters
332-
np = length(params)
333-
if np == 0 || isvarargtype(params[end])
334-
return u
335-
end
336-
if n == -1
337-
n = np
338-
elseif n np
339-
return u
340-
end
341-
end
342-
Head = (typename::Core.TypeName).wrapper
343-
hparams = Any[]
344-
for i = 1:n
345-
uparams = Any[]
346-
for t in ts
347-
tpᵢ = (t::DataType).parameters[i]
348-
tpᵢ isa Type || return u
349-
push!(uparams, tpᵢ)
350-
end
351-
push!(hparams, Union{uparams...})
352-
end
353-
return Head{hparams...}
354-
end
355-
356320
function unwraptv_ub(@nospecialize t)
357321
while isa(t, TypeVar)
358322
t = t.ub

base/reflection.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,9 +529,11 @@ function ismutabletype(@nospecialize t)
529529
@_total_meta
530530
t = unwrap_unionall(t)
531531
# TODO: what to do for `Union`?
532-
return isa(t, DataType) && t.name.flags & 0x2 == 0x2
532+
return isa(t, DataType) && ismutabletypename(t.name)
533533
end
534534

535+
ismutabletypename(tn::Core.TypeName) = tn.flags & 0x2 == 0x2
536+
535537
"""
536538
isstructtype(T) -> Bool
537539

0 commit comments

Comments
 (0)