Skip to content

Commit f818842

Browse files
authored
optimizations: better modeling and codegen for apply and svec calls (#59548)
- Use svec instead of tuple for arguments (better match for ABI which will require boxes) - Directly forward single svec argument, both runtime and codegen, without copying. - Optimize all consistant builtin functions of constant arguments, not just ones with special tfuncs. Reducing code duplication and divergence. - Codegen for `svec()` directly, so optimizer can see each store (and doesn't have to build the whole thing on the stack first). Written with help by Claude
2 parents 067b013 + 18f1c26 commit f818842

File tree

8 files changed

+180
-38
lines changed

8 files changed

+180
-38
lines changed

Compiler/src/ssair/passes.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,49 @@ function perform_lifting!(compact::IncrementalCompact,
874874
return Pair{Any, PhiNest}(stmt_val, PhiNest(visited_philikes, lifted_philikes, lifted_leaves, reverse_mapping, walker_callback))
875875
end
876876

877+
function lift_apply_args!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
878+
# Handle _apply_iterate calls: convert arguments to use `Core.svec`. The behavior of Core.svec (with boxing) better matches the ABI of codegen.
879+
compact[idx] = nothing
880+
for i in 4:length(stmt.args) # Skip iterate function, f, and first iterator
881+
arg = stmt.args[i]
882+
arg_type = argextype(arg, compact)
883+
svec_args = nothing
884+
if isa(arg_type, DataType) && arg_type.name === Tuple.name
885+
if isa(arg, SSAValue)
886+
arg_stmt = compact[arg][:stmt]
887+
if is_known_call(arg_stmt, Core.tuple, compact)
888+
svec_args = copy(arg_stmt.args)
889+
end
890+
end
891+
if svec_args === nothing
892+
# Fallback path: generate getfield calls for tuple elements
893+
tuple_length = length(arg_type.parameters)
894+
if tuple_length > 0 && !isvarargtype(arg_type.parameters[tuple_length])
895+
svec_args = Vector{Any}(undef, tuple_length + 1)
896+
for j in 1:tuple_length
897+
getfield_call = Expr(:call, GlobalRef(Core, :getfield), arg, j)
898+
getfield_type = arg_type.parameters[j]
899+
inst = compact[SSAValue(idx)]
900+
getfield_ssa = insert_node!(compact, SSAValue(idx), NewInstruction(getfield_call, getfield_type, NoCallInfo(), inst[:line], inst[:flag]))
901+
svec_args[j + 1] = getfield_ssa
902+
end
903+
end
904+
end
905+
end
906+
# Create Core.svec call if we have arguments
907+
if svec_args !== nothing
908+
svec_args[1] = GlobalRef(Core, :svec)
909+
new_svec_call = Expr(:call)
910+
new_svec_call.args = svec_args
911+
inst = compact[SSAValue(idx)]
912+
new_svec_ssa = insert_node!(compact, SSAValue(idx), NewInstruction(new_svec_call, SimpleVector, NoCallInfo(), inst[:line], inst[:flag]))
913+
stmt.args[i] = new_svec_ssa
914+
end
915+
end
916+
compact[idx] = stmt
917+
nothing
918+
end
919+
877920
function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr)
878921
length(stmt.args) != 3 && return
879922

@@ -1377,6 +1420,9 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
13771420
compact[SSAValue(idx)] = (compact[enter_ssa][:stmt]::EnterNode).scope
13781421
elseif isexpr(stmt, :new)
13791422
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
1423+
elseif is_known_call(stmt, Core._apply_iterate, compact)
1424+
length(stmt.args) >= 4 || continue
1425+
lift_apply_args!(compact, idx, stmt, 𝕃ₒ)
13801426
end
13811427
continue
13821428
end

Compiler/src/tfuncs.jl

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,15 @@ end
580580
add_tfunc(nfields, 1, 1, nfields_tfunc, 1)
581581
add_tfunc(Core._expr, 1, INT_INF, @nospecs((𝕃::AbstractLattice, args...)->Expr), 100)
582582
add_tfunc(svec, 0, INT_INF, @nospecs((𝕃::AbstractLattice, args...)->SimpleVector), 20)
583+
584+
@nospecs function _svec_len_tfunc(𝕃::AbstractLattice, s)
585+
if isa(s, Const) && isa(s.val, SimpleVector)
586+
return Const(length(s.val))
587+
end
588+
return Int
589+
end
590+
add_tfunc(Core._svec_len, 1, 1, _svec_len_tfunc, 1)
591+
583592
@nospecs function _svec_ref_tfunc(𝕃::AbstractLattice, s, i)
584593
if isa(s, Const) && isa(i, Const)
585594
s, i = s.val, i.val
@@ -1960,15 +1969,8 @@ function tuple_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any})
19601969
# UnionAll context is missing around this.
19611970
pop!(argtypes)
19621971
end
1963-
all_are_const = true
1964-
for i in 1:length(argtypes)
1965-
if !isa(argtypes[i], Const)
1966-
all_are_const = false
1967-
break
1968-
end
1969-
end
1970-
if all_are_const
1971-
return Const(ntuple(i::Int->argtypes[i].val, length(argtypes)))
1972+
if is_all_const_arg(argtypes, 1) # repeated from builtin_tfunction for the benefit of callers that use this tfunc directly
1973+
return Const(tuple(collect_const_args(argtypes, 1)...))
19721974
end
19731975
params = Vector{Any}(undef, length(argtypes))
19741976
anyinfo = false
@@ -2334,14 +2336,17 @@ function _builtin_nothrow(𝕃::AbstractLattice, @nospecialize(f::Builtin), argt
23342336
elseif f === Core.compilerbarrier
23352337
na == 2 || return false
23362338
return compilerbarrier_nothrow(argtypes[1], nothing)
2339+
elseif f === Core._svec_len
2340+
na == 1 || return false
2341+
return _svec_len_tfunc(𝕃, argtypes[1]) isa Const
23372342
elseif f === Core._svec_ref
23382343
na == 2 || return false
23392344
return _svec_ref_tfunc(𝕃, argtypes[1], argtypes[2]) isa Const
23402345
end
23412346
return false
23422347
end
23432348

2344-
# known to be always effect-free (in particular nothrow)
2349+
# known to be always effect-free (in particular also nothrow)
23452350
const _PURE_BUILTINS = Any[
23462351
tuple,
23472352
svec,
@@ -2370,6 +2375,8 @@ const _CONSISTENT_BUILTINS = Any[
23702375
donotdelete,
23712376
memoryrefnew,
23722377
memoryrefoffset,
2378+
Core._svec_len,
2379+
Core._svec_ref,
23732380
]
23742381

23752382
# known to be effect-free (but not necessarily nothrow)
@@ -2394,6 +2401,7 @@ const _EFFECT_FREE_BUILTINS = [
23942401
Core.throw_methoderror,
23952402
getglobal,
23962403
compilerbarrier,
2404+
Core._svec_len,
23972405
Core._svec_ref,
23982406
]
23992407

@@ -2428,6 +2436,7 @@ const _ARGMEM_BUILTINS = Any[
24282436
replacefield!,
24292437
setfield!,
24302438
swapfield!,
2439+
Core._svec_len,
24312440
Core._svec_ref,
24322441
]
24332442

@@ -2571,6 +2580,7 @@ const _EFFECTS_KNOWN_BUILTINS = Any[
25712580
# Core._primitivetype,
25722581
# Core._setsuper!,
25732582
# Core._structtype,
2583+
Core._svec_len,
25742584
Core._svec_ref,
25752585
# Core._typebody!,
25762586
Core._typevar,
@@ -2675,7 +2685,7 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
26752685
else
26762686
if contains_is(_CONSISTENT_BUILTINS, f)
26772687
consistent = ALWAYS_TRUE
2678-
elseif f === memoryrefget || f === memoryrefset! || f === memoryref_isassigned || f === Core._svec_ref
2688+
elseif f === memoryrefget || f === memoryrefset! || f === memoryref_isassigned || f === Core._svec_len || f === Core._svec_ref
26792689
consistent = CONSISTENT_IF_INACCESSIBLEMEMONLY
26802690
elseif f === Core._typevar || f === Core.memorynew
26812691
consistent = CONSISTENT_IF_NOTRETURNED
@@ -2784,11 +2794,12 @@ end
27842794
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
27852795
sv::Union{AbsIntState, Nothing})
27862796
𝕃ᵢ = typeinf_lattice(interp)
2787-
if isa(f, IntrinsicFunction)
2788-
if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes)
2789-
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
2797+
# Early constant evaluation for foldable builtins with all const args
2798+
if isa(f, IntrinsicFunction) ? is_pure_intrinsic_infer(f) : (f in _PURE_BUILTINS || (f in _CONSISTENT_BUILTINS && f in _EFFECT_FREE_BUILTINS))
2799+
if is_all_const_arg(argtypes, 1)
2800+
argvals = collect_const_args(argtypes, 1)
27902801
try
2791-
# unroll a few cases which have specialized codegen
2802+
# unroll a few common cases for better codegen
27922803
if length(argvals) == 1
27932804
return Const(f(argvals[1]))
27942805
elseif length(argvals) == 2
@@ -2802,6 +2813,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
28022813
return Bottom
28032814
end
28042815
end
2816+
end
2817+
if isa(f, IntrinsicFunction)
28052818
iidx = Int(reinterpret(Int32, f)) + 1
28062819
if iidx < 0 || iidx > length(T_IFUNC)
28072820
# unknown intrinsic
@@ -2828,6 +2841,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
28282841
end
28292842
tf = T_FFUNC_VAL[fidx]
28302843
end
2844+
28312845
if hasvarargtype(argtypes)
28322846
if length(argtypes) - 1 > tf[2]
28332847
# definitely too many arguments

Compiler/test/effects.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,7 @@ end
14741474
let effects = Base.infer_effects((Core.SimpleVector,Int); optimize=false) do svec, i
14751475
Core._svec_ref(svec, i)
14761476
end
1477-
@test !Compiler.is_consistent(effects)
1477+
@test Compiler.is_consistent(effects)
14781478
@test Compiler.is_effect_free(effects)
14791479
@test !Compiler.is_nothrow(effects)
14801480
@test Compiler.is_terminates(effects)

base/essentials.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -979,11 +979,7 @@ setindex!(A::MemoryRef{Any}, @nospecialize(x)) = (memoryrefset!(A, x, :not_atomi
979979

980980
getindex(v::SimpleVector, i::Int) = (@_foldable_meta; Core._svec_ref(v, i))
981981
function length(v::SimpleVector)
982-
@_total_meta
983-
t = @_gc_preserve_begin v
984-
len = unsafe_load(Ptr{Int}(pointer_from_objref(v)))
985-
@_gc_preserve_end t
986-
return len
982+
Core._svec_len(v)
987983
end
988984
firstindex(v::SimpleVector) = 1
989985
lastindex(v::SimpleVector) = length(v)

src/builtin_proto.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ extern "C" {
2020
XX(_primitivetype,"_primitivetype") \
2121
XX(_setsuper,"_setsuper!") \
2222
XX(_structtype,"_structtype") \
23+
XX(_svec_len,"_svec_len") \
2324
XX(_svec_ref,"_svec_ref") \
2425
XX(_typebody,"_typebody!") \
2526
XX(_typevar,"_typevar") \

src/builtins.c

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,9 +701,15 @@ JL_CALLABLE(jl_f__apply_iterate)
701701
return (jl_value_t*)t;
702702
}
703703
}
704-
else if (f == BUILTIN(tuple) && jl_is_tuple(args[1])) {
705-
return args[1];
704+
else if (f == BUILTIN(tuple)) {
705+
if (jl_is_tuple(args[1]))
706+
return args[1];
707+
if (jl_is_svec(args[1]))
708+
return jl_f_tuple(NULL, jl_svec_data(args[1]), jl_svec_len(args[1]));
706709
}
710+
// optimization for `f(svec...)`
711+
if (jl_is_svec(args[1]))
712+
return jl_apply_generic(f, jl_svec_data(args[1]), jl_svec_len(args[1]));
707713
}
708714
// estimate how many real arguments we appear to have
709715
size_t precount = 1;
@@ -2151,6 +2157,14 @@ JL_CALLABLE(jl_f__compute_sparams)
21512157
return (jl_value_t*)env;
21522158
}
21532159

2160+
JL_CALLABLE(jl_f__svec_len)
2161+
{
2162+
JL_NARGS(_svec_len, 1, 1);
2163+
jl_svec_t *s = (jl_svec_t*)args[0];
2164+
JL_TYPECHK(_svec_len, simplevector, (jl_value_t*)s);
2165+
return jl_box_long(jl_svec_len(s));
2166+
}
2167+
21542168
JL_CALLABLE(jl_f__svec_ref)
21552169
{
21562170
JL_NARGS(_svec_ref, 2, 2);

src/cgutils.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2228,6 +2228,9 @@ static jl_cgval_t typed_load(jl_codectx_t &ctx, Value *ptr, Value *idx_0based, j
22282228
}
22292229
Value *instr = nullptr;
22302230
if (!isboxed && jl_is_genericmemoryref_type(jltype)) {
2231+
//We don't specify the stronger expected memory ordering here because of fears it may interfere with vectorization and other optimizations
2232+
//if (Order == AtomicOrdering::NotAtomic)
2233+
// Order = AtomicOrdering::Monotonic;
22312234
// load these FCA as individual fields, so LLVM does not need to split them later
22322235
Value *fld0 = ctx.builder.CreateStructGEP(elty, ptr, 0);
22332236
LoadInst *load0 = ctx.builder.CreateAlignedLoad(elty->getStructElementType(0), fld0, Align(alignment), false);
@@ -2403,11 +2406,26 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
24032406
instr = load;
24042407
}
24052408
if (r) {
2406-
StoreInst *store = ctx.builder.CreateAlignedStore(r, ptr, Align(alignment));
2407-
store->setOrdering(Order == AtomicOrdering::NotAtomic && isboxed ? AtomicOrdering::Release : Order);
24082409
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, tbaa);
24092410
ai.noalias = MDNode::concatenate(aliasscope, ai.noalias);
2410-
ai.decorateInst(store);
2411+
if (false && !isboxed && Order == AtomicOrdering::NotAtomic && jl_is_genericmemoryref_type(jltype)) {
2412+
// if enabled, store these FCA as individual fields, so LLVM does not need to split them later and they can use release ordering
2413+
assert(r->getType() == ctx.types().T_jlgenericmemory);
2414+
Value *f1 = ctx.builder.CreateExtractValue(r, 0);
2415+
Value *f2 = ctx.builder.CreateExtractValue(r, 1);
2416+
static_assert(offsetof(jl_genericmemoryref_t, ptr_or_offset) == 0, "wrong field order");
2417+
StoreInst *store = ctx.builder.CreateAlignedStore(f1, ctx.builder.CreateStructGEP(ctx.types().T_jlgenericmemory, ptr, 0), Align(alignment));
2418+
store->setOrdering(AtomicOrdering::Release);
2419+
ai.decorateInst(store);
2420+
store = ctx.builder.CreateAlignedStore(f2, ctx.builder.CreateStructGEP(ctx.types().T_jlgenericmemory, ptr, 1), Align(alignment));
2421+
store->setOrdering(AtomicOrdering::Release);
2422+
ai.decorateInst(store);
2423+
}
2424+
else {
2425+
StoreInst *store = ctx.builder.CreateAlignedStore(r, ptr, Align(alignment));
2426+
store->setOrdering(Order == AtomicOrdering::NotAtomic && isboxed ? AtomicOrdering::Release : Order);
2427+
ai.decorateInst(store);
2428+
}
24112429
}
24122430
else {
24132431
assert(Order == AtomicOrdering::NotAtomic && !isboxed && rhs.typ == jltype);
@@ -4435,10 +4453,11 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg
44354453
for (size_t i = nargs; i < nf; i++) {
44364454
if (!jl_field_isptr(sty, i) && jl_is_uniontype(jl_field_type(sty, i))) {
44374455
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, strctinfo.tbaa);
4438-
ai.decorateInst(ctx.builder.CreateAlignedStore(
4456+
auto *store = ctx.builder.CreateAlignedStore(
44394457
ConstantInt::get(getInt8Ty(ctx.builder.getContext()), 0),
44404458
emit_ptrgep(ctx, strct, jl_field_offset(sty, i) + jl_field_size(sty, i) - 1),
4441-
Align(1)));
4459+
Align(1));
4460+
ai.decorateInst(store);
44424461
}
44434462
}
44444463
// TODO: verify that nargs <= nf (currently handled by front-end)

0 commit comments

Comments
 (0)