@@ -89,6 +89,122 @@ function has_ancestor(query::Module, target::Module)
89
89
end
90
90
end
91
91
92
+ const __skip_rewrite_func_set_lock = ReentrantLock ()
93
+ const __skip_rewrite_func_set = Set ([
94
+ # Avoid the 1.10 stackoverflow
95
+ typeof (Base. typed_hvcat),
96
+ typeof (Base. hvcat),
97
+ typeof (Core. Compiler. concrete_eval_eligible),
98
+ typeof (Core. Compiler. typeinf_type),
99
+ typeof (Core. Compiler. typeinf_ext),
100
+ # TODO : perhaps problematic calls in `traced_call`
101
+ # should be moved to TracedUtils.jl:
102
+ typeof (Reactant. ReactantCore. traced_call),
103
+ typeof (ReactantCore. is_traced),
104
+ # Perf optimization
105
+ typeof (Base. typemax),
106
+ typeof (Base. typemin),
107
+ typeof (Base. getproperty),
108
+ typeof (Base. vect),
109
+ typeof (Base. eltype),
110
+ typeof (Base. argtail),
111
+ typeof (Base. identity),
112
+ typeof (Base. print),
113
+ typeof (Base. println),
114
+ typeof (Base. show),
115
+ typeof (Base. show_delim_array),
116
+ typeof (Base. sprint),
117
+ typeof (Adapt. adapt_structure),
118
+ typeof (Core. is_top_bit_set),
119
+ typeof (Base. setindex_widen_up_to),
120
+ typeof (Base. typejoin),
121
+ typeof (Base. argtype_decl),
122
+ typeof (Base. arg_decl_parts),
123
+ typeof (Base. StackTraces. show_spec_sig),
124
+ typeof (Core. Compiler. return_type),
125
+ typeof (Core. throw_inexacterror),
126
+ typeof (Base. throw_boundserror),
127
+ typeof (Base. _shrink),
128
+ typeof (Base. _shrink!),
129
+ typeof (Base. ht_keyindex),
130
+ typeof (Base. checkindex),
131
+ typeof (Base. to_index),
132
+ @static (
133
+ if VERSION >= v " 1.11.0"
134
+ typeof (Base. memoryref)
135
+ end
136
+ ),
137
+ typeof (Reactant. materialize_traced_array),
138
+ ])
139
+
140
+ """
141
+ @skip_rewrite_func f
142
+
143
+ Mark function `f` so that Reactant's IR rewrite mechanism will skip it.
144
+ This can improve compilation time if it's safe to assume that no call inside `f`
145
+ will need a `@reactant_overlay` method.
146
+
147
+ !!! info
148
+ Note that this marks the whole function, not a specific method with a type
149
+ signature.
150
+
151
+ !!! warning
152
+ The macro call should be inside the `__init__` function. If you want to
153
+ mark it for precompilation, you must add the macro call in the global scope
154
+ too.
155
+
156
+ See also: [`@skip_rewrite_type`](@ref)
157
+ """
158
+ macro skip_rewrite_func (fname)
159
+ quote
160
+ @lock $ (Reactant. __skip_rewrite_func_set_lock) push! (
161
+ $ (Reactant. __skip_rewrite_func_set), typeof ($ (esc (fname)))
162
+ )
163
+ end
164
+ end
165
+
166
+ const __skip_rewrite_type_constructor_list_lock = ReentrantLock ()
167
+ const __skip_rewrite_type_constructor_list = [
168
+ # Don't rewrite Val
169
+ Type{Base. Val},
170
+ # Don't rewrite exception constructors
171
+ Type{<: Core.Exception },
172
+ # Don't rewrite traced constructors
173
+ Type{<: TracedRArray },
174
+ Type{<: TracedRNumber },
175
+ Type{MLIR. IR. Location},
176
+ Type{MLIR. IR. Block},
177
+ ]
178
+
179
+ """
180
+ @skip_rewrite_type MyStruct
181
+ @skip_rewrite_type Type{<:MyStruct}
182
+
183
+ Mark the construct function of `MyStruct` so that Reactant's IR rewrite mechanism
184
+ will skip it. It does the same as [`@skip_rewrite_func`](@ref) but for type
185
+ constructors.
186
+
187
+ If you want to mark the set of constructors over it's type parameters or over its
188
+ abstract type, you should use then the `Type{<:MyStruct}` syntax.
189
+
190
+ !!! warning
191
+ The macro call should be inside the `__init__` function. If you want to
192
+ mark it for precompilation, you must add the macro call in the global scope
193
+ too.
194
+ """
195
+ macro skip_rewrite_type (typ)
196
+ typ = if Base. isexpr (typ, :curly ) && typ. args[1 ] === :Type
197
+ typ
198
+ else
199
+ Expr (:curly , :Type , typ)
200
+ end
201
+ return quote
202
+ @lock $ (Reactant. __skip_rewrite_type_constructor_list_lock) push! (
203
+ $ (Reactant. __skip_rewrite_type_constructor_list), $ (esc (typ))
204
+ )
205
+ end
206
+ end
207
+
92
208
function should_rewrite_call (@nospecialize (ft))
93
209
# Don't rewrite builtin or intrinsics
94
210
if ft <: Core.IntrinsicFunction || ft <: Core.Builtin
@@ -123,66 +239,13 @@ function should_rewrite_call(@nospecialize(ft))
123
239
end
124
240
end
125
241
end
126
- # Don't rewrite Val
127
- if ft === Type{Base. Val}
128
- return false
129
- end
130
- # Don't rewrite exception constructors
131
- if ft <: Type{<:Core.Exception}
132
- return false
133
- end
134
-
135
- # Avoid the 1.10 stackoverflow
136
- if ft <: typeof (Base. typed_hvcat)
137
- return false
138
- end
139
- if ft <: typeof (Base. hvcat)
140
- return false
141
- end
142
- if ft <: typeof (Core. Compiler. concrete_eval_eligible)
143
- return false
144
- end
145
- if ft <: typeof (Core. Compiler. typeinf_type) || ft <: typeof (Core. Compiler. typeinf_ext)
146
- return false
147
- end
148
-
149
- # Don't rewrite traced constructors
150
- if ft <: Type{<:TracedRArray} ||
151
- ft <: Type{<:TracedRNumber} ||
152
- ft === Type{MLIR. IR. Location} ||
153
- ft === Type{MLIR. IR. Block} ||
154
- # TODO : perhaps problematic calls in `traced_call`
155
- # should be moved to TracedUtils.jl:
156
- ft <: typeof (Reactant. ReactantCore. traced_call) ||
157
- ft <: typeof (ReactantCore. is_traced)
158
- return false
159
- end
160
242
161
- # Perf optimizations
162
- if ft <: typeof (Core . Compiler . return_type )
243
+ # `ft isa Type` is for performance as it avoids checking against all the list, but can be removed if problematic
244
+ if ft isa Type && any (t -> ft <: t , __skip_rewrite_type_constructor_list )
163
245
return false
164
246
end
165
247
166
- # Perf optimizations
167
- if ft <: typeof (Base. typemax) ||
168
- ft <: typeof (Base. typemin) ||
169
- ft <: typeof (Base. getproperty) ||
170
- ft <: typeof (Base. vect) ||
171
- ft <: typeof (Base. eltype) ||
172
- ft <: typeof (Base. argtail) ||
173
- ft <: typeof (Base. identity) ||
174
- ft <: typeof (Base. print) ||
175
- ft <: typeof (Base. println) ||
176
- ft <: typeof (Base. show) ||
177
- ft <: typeof (Base. show_delim_array) ||
178
- ft <: typeof (Base. sprint) ||
179
- ft <: typeof (Adapt. adapt_structure) ||
180
- ft <: typeof (Core. is_top_bit_set) ||
181
- ft <: typeof (Base. setindex_widen_up_to) ||
182
- ft <: typeof (Base. typejoin) ||
183
- ft <: typeof (Base. argtype_decl) ||
184
- ft <: typeof (Base. arg_decl_parts) ||
185
- ft <: typeof (Base. StackTraces. show_spec_sig)
248
+ if ft in __skip_rewrite_func_set
186
249
return false
187
250
end
188
251
192
255
193
256
# by default, same as `should_rewrite_call`
194
257
function should_rewrite_invoke (@nospecialize (ft), @nospecialize (args))
258
+ # TODO how can we extend `@skip_rewrite` to methods?
195
259
if ft <: typeof (repeat) && (args == Tuple{String,Int64} || args == Tuple{Char,Int64})
196
260
return false
197
261
end
0 commit comments