Skip to content

Commit a65da93

Browse files
authored
Introduce @skip_rewrite_func, @skip_rewrite_type to extend should_rewrite_call (#1377)
* Introduce `@skip_rewrite` to extend `should_rewrite_call` * Split `@skip_rewrite` into `@skip_rewrite_func` and `@skip_rewrite_type` * Try fix `Base.memoryref` symbol in Julia 1.10 * Add admonitions to docstrings * Test `@skip_rewrite_func` * Comment test * Fix test * Add locks around `__skip_rewrite_func_set`, `__skip_rewrite_type_constructor_list` * Add `@skip_rewrite_*` macros to docs
1 parent 632964e commit a65da93

File tree

3 files changed

+148
-56
lines changed

3 files changed

+148
-56
lines changed

docs/src/api/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ within_compile
2929
@code_xla
3030
```
3131

32+
## Tracing customization
33+
34+
```@docs
35+
Reactant.@skip_rewrite_func
36+
Reactant.@skip_rewrite_type
37+
```
38+
3239
## Profile XLA
3340

3441
Reactant can hook into XLA's profiler to generate compilation and execution traces.

src/utils.jl

Lines changed: 120 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,122 @@ function has_ancestor(query::Module, target::Module)
8989
end
9090
end
9191

92+
const __skip_rewrite_func_set_lock = ReentrantLock()
93+
const __skip_rewrite_func_set = Set([
94+
# Avoid the 1.10 stackoverflow
95+
typeof(Base.typed_hvcat),
96+
typeof(Base.hvcat),
97+
typeof(Core.Compiler.concrete_eval_eligible),
98+
typeof(Core.Compiler.typeinf_type),
99+
typeof(Core.Compiler.typeinf_ext),
100+
# TODO: perhaps problematic calls in `traced_call`
101+
# should be moved to TracedUtils.jl:
102+
typeof(Reactant.ReactantCore.traced_call),
103+
typeof(ReactantCore.is_traced),
104+
# Perf optimization
105+
typeof(Base.typemax),
106+
typeof(Base.typemin),
107+
typeof(Base.getproperty),
108+
typeof(Base.vect),
109+
typeof(Base.eltype),
110+
typeof(Base.argtail),
111+
typeof(Base.identity),
112+
typeof(Base.print),
113+
typeof(Base.println),
114+
typeof(Base.show),
115+
typeof(Base.show_delim_array),
116+
typeof(Base.sprint),
117+
typeof(Adapt.adapt_structure),
118+
typeof(Core.is_top_bit_set),
119+
typeof(Base.setindex_widen_up_to),
120+
typeof(Base.typejoin),
121+
typeof(Base.argtype_decl),
122+
typeof(Base.arg_decl_parts),
123+
typeof(Base.StackTraces.show_spec_sig),
124+
typeof(Core.Compiler.return_type),
125+
typeof(Core.throw_inexacterror),
126+
typeof(Base.throw_boundserror),
127+
typeof(Base._shrink),
128+
typeof(Base._shrink!),
129+
typeof(Base.ht_keyindex),
130+
typeof(Base.checkindex),
131+
typeof(Base.to_index),
132+
@static(
133+
if VERSION >= v"1.11.0"
134+
typeof(Base.memoryref)
135+
end
136+
),
137+
typeof(Reactant.materialize_traced_array),
138+
])
139+
140+
"""
141+
@skip_rewrite_func f
142+
143+
Mark function `f` so that Reactant's IR rewrite mechanism will skip it.
144+
This can improve compilation time if it's safe to assume that no call inside `f`
145+
will need a `@reactant_overlay` method.
146+
147+
!!! info
148+
Note that this marks the whole function, not a specific method with a type
149+
signature.
150+
151+
!!! warning
152+
The macro call should be inside the `__init__` function. If you want to
153+
mark it for precompilation, you must add the macro call in the global scope
154+
too.
155+
156+
See also: [`@skip_rewrite_type`](@ref)
157+
"""
158+
macro skip_rewrite_func(fname)
159+
quote
160+
@lock $(Reactant.__skip_rewrite_func_set_lock) push!(
161+
$(Reactant.__skip_rewrite_func_set), typeof($(esc(fname)))
162+
)
163+
end
164+
end
165+
166+
const __skip_rewrite_type_constructor_list_lock = ReentrantLock()
167+
const __skip_rewrite_type_constructor_list = [
168+
# Don't rewrite Val
169+
Type{Base.Val},
170+
# Don't rewrite exception constructors
171+
Type{<:Core.Exception},
172+
# Don't rewrite traced constructors
173+
Type{<:TracedRArray},
174+
Type{<:TracedRNumber},
175+
Type{MLIR.IR.Location},
176+
Type{MLIR.IR.Block},
177+
]
178+
179+
"""
180+
@skip_rewrite_type MyStruct
181+
@skip_rewrite_type Type{<:MyStruct}
182+
183+
Mark the construct function of `MyStruct` so that Reactant's IR rewrite mechanism
184+
will skip it. It does the same as [`@skip_rewrite_func`](@ref) but for type
185+
constructors.
186+
187+
If you want to mark the set of constructors over it's type parameters or over its
188+
abstract type, you should use then the `Type{<:MyStruct}` syntax.
189+
190+
!!! warning
191+
The macro call should be inside the `__init__` function. If you want to
192+
mark it for precompilation, you must add the macro call in the global scope
193+
too.
194+
"""
195+
macro skip_rewrite_type(typ)
196+
typ = if Base.isexpr(typ, :curly) && typ.args[1] === :Type
197+
typ
198+
else
199+
Expr(:curly, :Type, typ)
200+
end
201+
return quote
202+
@lock $(Reactant.__skip_rewrite_type_constructor_list_lock) push!(
203+
$(Reactant.__skip_rewrite_type_constructor_list), $(esc(typ))
204+
)
205+
end
206+
end
207+
92208
function should_rewrite_call(@nospecialize(ft))
93209
# Don't rewrite builtin or intrinsics
94210
if ft <: Core.IntrinsicFunction || ft <: Core.Builtin
@@ -123,66 +239,13 @@ function should_rewrite_call(@nospecialize(ft))
123239
end
124240
end
125241
end
126-
# Don't rewrite Val
127-
if ft === Type{Base.Val}
128-
return false
129-
end
130-
# Don't rewrite exception constructors
131-
if ft <: Type{<:Core.Exception}
132-
return false
133-
end
134-
135-
# Avoid the 1.10 stackoverflow
136-
if ft <: typeof(Base.typed_hvcat)
137-
return false
138-
end
139-
if ft <: typeof(Base.hvcat)
140-
return false
141-
end
142-
if ft <: typeof(Core.Compiler.concrete_eval_eligible)
143-
return false
144-
end
145-
if ft <: typeof(Core.Compiler.typeinf_type) || ft <: typeof(Core.Compiler.typeinf_ext)
146-
return false
147-
end
148-
149-
# Don't rewrite traced constructors
150-
if ft <: Type{<:TracedRArray} ||
151-
ft <: Type{<:TracedRNumber} ||
152-
ft === Type{MLIR.IR.Location} ||
153-
ft === Type{MLIR.IR.Block} ||
154-
# TODO: perhaps problematic calls in `traced_call`
155-
# should be moved to TracedUtils.jl:
156-
ft <: typeof(Reactant.ReactantCore.traced_call) ||
157-
ft <: typeof(ReactantCore.is_traced)
158-
return false
159-
end
160242

161-
# Perf optimizations
162-
if ft <: typeof(Core.Compiler.return_type)
243+
# `ft isa Type` is for performance as it avoids checking against all the list, but can be removed if problematic
244+
if ft isa Type && any(t -> ft <: t, __skip_rewrite_type_constructor_list)
163245
return false
164246
end
165247

166-
# Perf optimizations
167-
if ft <: typeof(Base.typemax) ||
168-
ft <: typeof(Base.typemin) ||
169-
ft <: typeof(Base.getproperty) ||
170-
ft <: typeof(Base.vect) ||
171-
ft <: typeof(Base.eltype) ||
172-
ft <: typeof(Base.argtail) ||
173-
ft <: typeof(Base.identity) ||
174-
ft <: typeof(Base.print) ||
175-
ft <: typeof(Base.println) ||
176-
ft <: typeof(Base.show) ||
177-
ft <: typeof(Base.show_delim_array) ||
178-
ft <: typeof(Base.sprint) ||
179-
ft <: typeof(Adapt.adapt_structure) ||
180-
ft <: typeof(Core.is_top_bit_set) ||
181-
ft <: typeof(Base.setindex_widen_up_to) ||
182-
ft <: typeof(Base.typejoin) ||
183-
ft <: typeof(Base.argtype_decl) ||
184-
ft <: typeof(Base.arg_decl_parts) ||
185-
ft <: typeof(Base.StackTraces.show_spec_sig)
248+
if ft in __skip_rewrite_func_set
186249
return false
187250
end
188251

@@ -192,6 +255,7 @@ end
192255

193256
# by default, same as `should_rewrite_call`
194257
function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args))
258+
# TODO how can we extend `@skip_rewrite` to methods?
195259
if ft <: typeof(repeat) && (args == Tuple{String,Int64} || args == Tuple{Char,Int64})
196260
return false
197261
end

test/tracing.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,4 +286,25 @@ end
286286
@test opt_traced.epsilon isa ConcreteRNumber{Float64}
287287
@test opt_traced.centred isa Bool
288288
end
289+
290+
@testset "@skip_rewrite_func" begin
291+
a = ConcreteRArray([1.0 2.0; 3.0 4.0])
292+
293+
# TODO we should test it with a type-unstable method
294+
add_skip_rewrite(x) = x + x
295+
Reactant.@skip_rewrite_func add_skip_rewrite
296+
297+
# wrapper because `@skip_rewrite_*` doesn't work with top-functions
298+
f(x) = add_skip_rewrite(x)
299+
300+
# warmup
301+
@code_hlo optimize = false f(a)
302+
303+
t = @timed @code_hlo optimize = false f(a)
304+
305+
# `@timed` only measures compile time from v1.11.0 onward
306+
@static if VERSION >= v"1.11.0"
307+
@test iszero(t.compile_time)
308+
end
309+
end
289310
end

0 commit comments

Comments
 (0)