Skip to content

Commit 3995278

Browse files
authored
SROA: generalize unswitchtupleunion optimization (#50502)
This commit improves SROA pass by extending the `unswitchtupleunion` optimization to handle the general parametric types, e.g.: ```julia julia> struct A{T} x::T end; julia> function foo(a1, a2, c) t = c ? A(a1) : A(a2) return getfield(t, :x) end; julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA")) ``` > Before ``` 2 1 ─ goto #3 if not _4 │ 2 ─ %2 = %new(A{Int64}, _2)::A{Int64} │╻ A └── goto #4 │ 3 ─ %4 = %new(A{Float64}, _3)::A{Float64} │╻ A 4 ┄ %5 = φ (#2 => %2, #3 => %4)::Union{A{Float64}, A{Int64}} │ 3 │ %6 = Main.getfield(%5, :x)::Union{Float64, Int64} │ └── return %6 │ => Union{Float64, Int64} ``` > After ``` julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA")) 2 1 ─ goto #3 if not _4 │ 2 ─ nothing::A{Int64} │╻ A └── goto #4 │ 3 ─ nothing::A{Float64} │╻ A 4 ┄ %8 = φ (#2 => _2, #3 => _3)::Union{Float64, Int64} │ │ nothing::Union{A{Float64}, A{Int64}} 3 │ %6 = %8::Union{Float64, Int64} │ └── return %6 │ => Union{Float64, Int64} ```
1 parent 0f6bfd6 commit 3995278

File tree

3 files changed

+59
-21
lines changed

3 files changed

+59
-21
lines changed

base/compiler/ssair/passes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,8 +1107,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
11071107
end
11081108
struct_typ = widenconst(argextype(val, compact))
11091109
struct_typ_unwrapped = unwrap_unionall(struct_typ)
1110-
if isa(struct_typ, Union) && struct_typ <: Tuple
1111-
struct_typ_unwrapped = unswitchtupleunion(struct_typ_unwrapped)
1110+
if isa(struct_typ, Union)
1111+
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
11121112
end
11131113
if isa(struct_typ_unwrapped, Union) && is_isdefined
11141114
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)

base/compiler/typeutils.jl

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -317,33 +317,40 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
317317
return depth
318318
end
319319

320-
# convert a Union of Tuple types to a Tuple of Unions
321-
unswitchtupleunion(u::Union) = unswitchtypeunion(u, Tuple.name)
322-
320+
# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
323321
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
324322
ts = uniontypes(u)
325323
n = -1
326324
for t in ts
327-
if t isa DataType
328-
if typename === nothing
329-
typename = t.name
330-
elseif typename !== t.name
331-
return u
332-
end
333-
if length(t.parameters) != 0 && !isvarargtype(t.parameters[end])
334-
if n == -1
335-
n = length(t.parameters)
336-
elseif n != length(t.parameters)
337-
return u
338-
end
339-
end
340-
else
325+
t isa DataType || return u
326+
if typename === nothing
327+
typename = t.name
328+
elseif typename !== t.name
329+
return u
330+
end
331+
params = t.parameters
332+
np = length(params)
333+
if np == 0 || isvarargtype(params[end])
334+
return u
335+
end
336+
if n == -1
337+
n = np
338+
elseif n np
341339
return u
342340
end
343341
end
344342
Head = (typename::Core.TypeName).wrapper
345-
unionparams = Any[ Union{Any[(t::DataType).parameters[i] for t in ts]...} for i in 1:n ]
346-
return Head{unionparams...}
343+
hparams = Any[]
344+
for i = 1:n
345+
uparams = Any[]
346+
for t in ts
347+
tpᵢ = (t::DataType).parameters[i]
348+
tpᵢ isa Type || return u
349+
push!(uparams, tpᵢ)
350+
end
351+
push!(hparams, Union{uparams...})
352+
end
353+
return Head{hparams...}
347354
end
348355

349356
function unwraptv_ub(@nospecialize t)

test/compiler/irpasses.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,3 +1390,34 @@ function wrap1_wrap1_wrapper(b, x, y)
13901390
end
13911391
@test wrap1_wrap1_wrapper(true, 1, 1.0) === 1.0
13921392
@test wrap1_wrap1_wrapper(false, 1, 1.0) === 1
1393+
1394+
# Test unswitching-union optimization within SRO Apass
1395+
function sroaunswitchuniontuple(c, x1, x2)
1396+
t = c ? (x1,) : (x2,)
1397+
return getfield(t, 1)
1398+
end
1399+
struct SROAUnswitchUnion1{T}
1400+
x::T
1401+
end
1402+
struct SROAUnswitchUnion2{S,T}
1403+
x::T
1404+
@inline SROAUnswitchUnion2{S}(x::T) where {S,T} = new{S,T}(x)
1405+
end
1406+
function sroaunswitchunionstruct1(c, x1, x2)
1407+
x = c ? SROAUnswitchUnion1(x1) : SROAUnswitchUnion1(x2)
1408+
return getfield(x, :x)
1409+
end
1410+
function sroaunswitchunionstruct2(c, x1, x2)
1411+
x = c ? SROAUnswitchUnion2{:a}(x1) : SROAUnswitchUnion2{:a}(x2)
1412+
return getfield(x, :x)
1413+
end
1414+
let src = code_typed1(sroaunswitchuniontuple, Tuple{Bool, Int, Float64})
1415+
@test count(isnew, src.code) == 0
1416+
@test count(iscall((src, getfield)), src.code) == 0
1417+
end
1418+
let src = code_typed1(sroaunswitchunionstruct1, Tuple{Bool, Int, Float64})
1419+
@test count(isnew, src.code) == 0
1420+
@test count(iscall((src, getfield)), src.code) == 0
1421+
end
1422+
@test sroaunswitchunionstruct2(true, 1, 1.0) === 1
1423+
@test sroaunswitchunionstruct2(false, 1, 1.0) === 1.0

0 commit comments

Comments
 (0)