Skip to content

Commit 77572bc

Browse files
timholyaviateskmartinholters
authored
Allow constant-propagation to be disabled (#42125)
Our heuristics for constant propagation are imperfect (and probably never will be perfect), and I've now seen many examples of methods that no developer would ask to have const-propped get that treatment. In some cases the cost for latency/precompilation is very large. This renames `@aggressive_constprop` to `@constprop` and allows two settings, `:aggressive` and `:none`. Closes #38983 Co-authored-by: Shuhei Kadowaki <[email protected]> Co-authored-by: Martin Holters <[email protected]>
1 parent 50ab3a9 commit 77572bc

File tree

13 files changed

+120
-69
lines changed

13 files changed

+120
-69
lines changed

base/char.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ represents a valid Unicode character.
4545
"""
4646
Char
4747

48-
@aggressive_constprop (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x))
49-
@aggressive_constprop AbstractChar(x::Number) = Char(x)
50-
@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x))
51-
@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T
48+
@constprop :aggressive (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x))
49+
@constprop :aggressive AbstractChar(x::Number) = Char(x)
50+
@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x))
51+
@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T
5252
(::Type{T})(x::T) where {T<:AbstractChar} = x
5353

5454
"""
@@ -75,7 +75,7 @@ return a different-sized integer (e.g. `UInt8`).
7575
"""
7676
function codepoint end
7777

78-
@aggressive_constprop codepoint(c::Char) = UInt32(c)
78+
@constprop :aggressive codepoint(c::Char) = UInt32(c)
7979

8080
struct InvalidCharError{T<:AbstractChar} <: Exception
8181
char::T
@@ -124,7 +124,7 @@ See also [`decode_overlong`](@ref) and [`show_invalid`](@ref).
124124
"""
125125
isoverlong(c::AbstractChar) = false
126126

127-
@aggressive_constprop function UInt32(c::Char)
127+
@constprop :aggressive function UInt32(c::Char)
128128
# TODO: use optimized inline LLVM
129129
u = bitcast(UInt32, c)
130130
u < 0x80000000 && return u >> 24
@@ -148,7 +148,7 @@ that support overlong encodings should implement `Base.decode_overlong`.
148148
"""
149149
function decode_overlong end
150150

151-
@aggressive_constprop function decode_overlong(c::Char)
151+
@constprop :aggressive function decode_overlong(c::Char)
152152
u = bitcast(UInt32, c)
153153
l1 = leading_ones(u)
154154
t0 = trailing_zeros(u) & 56
@@ -158,7 +158,7 @@ function decode_overlong end
158158
((u & 0x007f0000) >> 4) | ((u & 0x7f000000) >> 6)
159159
end
160160

161-
@aggressive_constprop function Char(u::UInt32)
161+
@constprop :aggressive function Char(u::UInt32)
162162
u < 0x80 && return bitcast(Char, u << 24)
163163
u < 0x00200000 || throw_code_point_err(u)
164164
c = ((u << 0) & 0x0000003f) | ((u << 2) & 0x00003f00) |
@@ -169,14 +169,14 @@ end
169169
bitcast(Char, c)
170170
end
171171

172-
@aggressive_constprop @noinline UInt32_cold(c::Char) = UInt32(c)
173-
@aggressive_constprop function (T::Union{Type{Int8},Type{UInt8}})(c::Char)
172+
@constprop :aggressive @noinline UInt32_cold(c::Char) = UInt32(c)
173+
@constprop :aggressive function (T::Union{Type{Int8},Type{UInt8}})(c::Char)
174174
i = bitcast(Int32, c)
175175
i 0 ? ((i >>> 24) % T) : T(UInt32_cold(c))
176176
end
177177

178-
@aggressive_constprop @noinline Char_cold(b::UInt32) = Char(b)
179-
@aggressive_constprop function Char(b::Union{Int8,UInt8})
178+
@constprop :aggressive @noinline Char_cold(b::UInt32) = Char(b)
179+
@constprop :aggressive function Char(b::Union{Int8,UInt8})
180180
0 b 0x7f ? bitcast(Char, (b % UInt32) << 24) : Char_cold(UInt32(b))
181181
end
182182

base/compiler/abstractinterpretation.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,10 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
572572
return nothing
573573
end
574574
method = match.method
575+
if method.constprop == 0x02
576+
add_remark!(interp, sv, "[constprop] Disabled by method parameter")
577+
return nothing
578+
end
575579
force = force_const_prop(interp, f, method)
576580
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
577581
nargs::Int = method.nargs
@@ -653,7 +657,7 @@ function is_allconst(argtypes::Vector{Any})
653657
end
654658

655659
function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method::Method)
656-
return method.aggressive_constprop ||
660+
return method.constprop == 0x01 ||
657661
InferenceParams(interp).aggressive_constant_propagation ||
658662
istopfunction(f, :getproperty) ||
659663
istopfunction(f, :setproperty!)

base/expr.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,16 +349,26 @@ macro pure(ex)
349349
end
350350

351351
"""
352-
@aggressive_constprop ex
353-
@aggressive_constprop(ex)
354-
355-
`@aggressive_constprop` requests more aggressive interprocedural constant
356-
propagation for the annotated function. For a method where the return type
357-
depends on the value of the arguments, this can yield improved inference results
358-
at the cost of additional compile time.
359-
"""
360-
macro aggressive_constprop(ex)
361-
esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
352+
@constprop setting ex
353+
@constprop(setting, ex)
354+
355+
`@constprop` controls the mode of interprocedural constant propagation for the
356+
annotated function. Two `setting`s are supported:
357+
358+
- `@constprop :aggressive ex`: apply constant propagation aggressively.
359+
For a method where the return type depends on the value of the arguments,
360+
this can yield improved inference results at the cost of additional compile time.
361+
- `@constprop :none ex`: disable constant propagation. This can reduce compile
362+
times for functions that Julia might otherwise deem worthy of constant-propagation.
363+
Common cases are for functions with `Bool`- or `Symbol`-valued arguments or keyword arguments.
364+
"""
365+
macro constprop(setting, ex)
366+
if isa(setting, QuoteNode)
367+
setting = setting.value
368+
end
369+
setting === :aggressive && return esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
370+
setting === :none && return esc(isa(ex, Expr) ? pushmeta!(ex, :no_constprop) : ex)
371+
throw(ArgumentError("@constprop $setting not supported"))
362372
end
363373

364374
"""

src/ast.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
5959
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
6060
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
6161
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
62-
jl_sym_t *aggressive_constprop_sym;
62+
jl_sym_t *aggressive_constprop_sym; jl_sym_t *no_constprop_sym;
6363
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
6464
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
6565
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
@@ -399,6 +399,7 @@ void jl_init_common_symbols(void)
399399
polly_sym = jl_symbol("polly");
400400
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
401401
aggressive_constprop_sym = jl_symbol("aggressive_constprop");
402+
no_constprop_sym = jl_symbol("no_constprop");
402403
isdefined_sym = jl_symbol("isdefined");
403404
nospecialize_sym = jl_symbol("nospecialize");
404405
specialize_sym = jl_symbol("specialize");

src/dump.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
671671
write_int8(s->s, m->isva);
672672
write_int8(s->s, m->pure);
673673
write_int8(s->s, m->is_for_opaque_closure);
674-
write_int8(s->s, m->aggressive_constprop);
674+
write_int8(s->s, m->constprop);
675675
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
676676
jl_serialize_value(s, (jl_value_t*)m->roots);
677677
jl_serialize_value(s, (jl_value_t*)m->ccallable);
@@ -1525,7 +1525,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
15251525
m->isva = read_int8(s->s);
15261526
m->pure = read_int8(s->s);
15271527
m->is_for_opaque_closure = read_int8(s->s);
1528-
m->aggressive_constprop = read_int8(s->s);
1528+
m->constprop = read_int8(s->s);
15291529
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
15301530
jl_gc_wb(m, m->slot_syms);
15311531
m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots);

src/ircode.c

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,17 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
381381
}
382382
}
383383

384+
static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop)
385+
{
386+
jl_code_info_flags_t flags;
387+
flags.bits.pure = pure;
388+
flags.bits.propagate_inbounds = propagate_inbounds;
389+
flags.bits.inlineable = inlineable;
390+
flags.bits.inferred = inferred;
391+
flags.bits.constprop = constprop;
392+
return flags;
393+
}
394+
384395
// --- decoding ---
385396

386397
static jl_value_t *jl_decode_value(jl_ircode_state *s) JL_GC_DISABLED;
@@ -702,12 +713,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
702713
jl_current_task->ptls
703714
};
704715

705-
uint8_t flags = (code->aggressive_constprop << 4)
706-
| (code->inferred << 3)
707-
| (code->inlineable << 2)
708-
| (code->propagate_inbounds << 1)
709-
| (code->pure << 0);
710-
write_uint8(s.s, flags);
716+
jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inlineable, code->inferred, code->constprop);
717+
write_uint8(s.s, flags.packed);
711718

712719
size_t nslots = jl_array_len(code->slotflags);
713720
assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions
@@ -787,12 +794,13 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
787794
};
788795

789796
jl_code_info_t *code = jl_new_code_info_uninit();
790-
uint8_t flags = read_uint8(s.s);
791-
code->aggressive_constprop = !!(flags & (1 << 4));
792-
code->inferred = !!(flags & (1 << 3));
793-
code->inlineable = !!(flags & (1 << 2));
794-
code->propagate_inbounds = !!(flags & (1 << 1));
795-
code->pure = !!(flags & (1 << 0));
797+
jl_code_info_flags_t flags;
798+
flags.packed = read_uint8(s.s);
799+
code->constprop = flags.bits.constprop;
800+
code->inferred = flags.bits.inferred;
801+
code->inlineable = flags.bits.inlineable;
802+
code->propagate_inbounds = flags.bits.propagate_inbounds;
803+
code->pure = flags.bits.pure;
796804

797805
size_t nslots = read_int32(&src);
798806
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
@@ -847,26 +855,29 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data)
847855
if (jl_is_code_info(data))
848856
return ((jl_code_info_t*)data)->inferred;
849857
assert(jl_typeis(data, jl_array_uint8_type));
850-
uint8_t flags = ((uint8_t*)data->data)[0];
851-
return !!(flags & (1 << 3));
858+
jl_code_info_flags_t flags;
859+
flags.packed = ((uint8_t*)data->data)[0];
860+
return flags.bits.inferred;
852861
}
853862

854863
JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data)
855864
{
856865
if (jl_is_code_info(data))
857866
return ((jl_code_info_t*)data)->inlineable;
858867
assert(jl_typeis(data, jl_array_uint8_type));
859-
uint8_t flags = ((uint8_t*)data->data)[0];
860-
return !!(flags & (1 << 2));
868+
jl_code_info_flags_t flags;
869+
flags.packed = ((uint8_t*)data->data)[0];
870+
return flags.bits.inlineable;
861871
}
862872

863873
JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data)
864874
{
865875
if (jl_is_code_info(data))
866876
return ((jl_code_info_t*)data)->pure;
867877
assert(jl_typeis(data, jl_array_uint8_type));
868-
uint8_t flags = ((uint8_t*)data->data)[0];
869-
return !!(flags & (1 << 0));
878+
jl_code_info_flags_t flags;
879+
flags.packed = ((uint8_t*)data->data)[0];
880+
return flags.bits.pure;
870881
}
871882

872883
JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms)

src/jltypes.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,7 +2348,7 @@ void jl_init_types(void) JL_GC_DISABLED
23482348
"inlineable",
23492349
"propagate_inbounds",
23502350
"pure",
2351-
"aggressive_constprop"),
2351+
"constprop"),
23522352
jl_svec(19,
23532353
jl_array_any_type,
23542354
jl_array_int32_type,
@@ -2368,7 +2368,7 @@ void jl_init_types(void) JL_GC_DISABLED
23682368
jl_bool_type,
23692369
jl_bool_type,
23702370
jl_bool_type,
2371-
jl_bool_type),
2371+
jl_uint8_type),
23722372
jl_emptysvec,
23732373
0, 1, 19);
23742374

@@ -2401,7 +2401,7 @@ void jl_init_types(void) JL_GC_DISABLED
24012401
"isva",
24022402
"pure",
24032403
"is_for_opaque_closure",
2404-
"aggressive_constprop"),
2404+
"constprop"),
24052405
jl_svec(26,
24062406
jl_symbol_type,
24072407
jl_module_type,
@@ -2428,7 +2428,7 @@ void jl_init_types(void) JL_GC_DISABLED
24282428
jl_bool_type,
24292429
jl_bool_type,
24302430
jl_bool_type,
2431-
jl_bool_type),
2431+
jl_uint8_type),
24322432
jl_emptysvec,
24332433
0, 1, 10);
24342434

src/julia.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ typedef struct _jl_code_info_t {
278278
uint8_t inlineable;
279279
uint8_t propagate_inbounds;
280280
uint8_t pure;
281-
uint8_t aggressive_constprop;
281+
// uint8 settings
282+
uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none
282283
} jl_code_info_t;
283284

284285
// This type describes a single method definition, and stores data
@@ -326,7 +327,8 @@ typedef struct _jl_method_t {
326327
uint8_t isva;
327328
uint8_t pure;
328329
uint8_t is_for_opaque_closure;
329-
uint8_t aggressive_constprop;
330+
// uint8 settings
331+
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none
330332

331333
// hidden fields:
332334
// lock for modifications to the method

src/julia_internal.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,24 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO
475475
return v;
476476
}
477477

478+
// -- helper types -- //
479+
480+
typedef struct {
481+
uint8_t pure:1;
482+
uint8_t propagate_inbounds:1;
483+
uint8_t inlineable:1;
484+
uint8_t inferred:1;
485+
uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none
486+
} jl_code_info_flags_bitfield_t;
487+
488+
typedef union {
489+
jl_code_info_flags_bitfield_t bits;
490+
uint8_t packed;
491+
} jl_code_info_flags_t;
478492

479493
// -- functions -- //
480494

495+
// jl_code_info_flag_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop);
481496
jl_code_info_t *jl_type_infer(jl_method_instance_t *li, size_t world, int force);
482497
jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *meth JL_PROPAGATES_ROOT, size_t world);
483498
jl_code_instance_t *jl_generate_fptr(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world);
@@ -1376,7 +1391,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
13761391
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
13771392
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
13781393
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
1379-
extern jl_sym_t *aggressive_constprop_sym;
1394+
extern jl_sym_t *aggressive_constprop_sym; extern jl_sym_t *no_constprop_sym;
13801395
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
13811396
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
13821397
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;

src/method.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
303303
else if (ma == (jl_value_t*)propagate_inbounds_sym)
304304
li->propagate_inbounds = 1;
305305
else if (ma == (jl_value_t*)aggressive_constprop_sym)
306-
li->aggressive_constprop = 1;
306+
li->constprop = 1;
307+
else if (ma == (jl_value_t*)no_constprop_sym)
308+
li->constprop = 2;
307309
else
308310
jl_array_ptr_set(meta, ins++, ma);
309311
}
@@ -443,7 +445,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
443445
src->propagate_inbounds = 0;
444446
src->pure = 0;
445447
src->edges = jl_nothing;
446-
src->aggressive_constprop = 0;
448+
src->constprop = 0;
447449
return src;
448450
}
449451

@@ -630,7 +632,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
630632
}
631633
m->called = called;
632634
m->pure = src->pure;
633-
m->aggressive_constprop = src->aggressive_constprop;
635+
m->constprop = src->constprop;
634636
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);
635637

636638
jl_array_t *copy = NULL;
@@ -746,7 +748,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
746748
m->primary_world = 1;
747749
m->deleted_world = ~(size_t)0;
748750
m->is_for_opaque_closure = 0;
749-
m->aggressive_constprop = 0;
751+
m->constprop = 0;
750752
JL_MUTEX_INIT(&m->writelock);
751753
return m;
752754
}

0 commit comments

Comments
 (0)