Skip to content

Commit 8b40632

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/master'
2 parents afa2903 + 1fa67c9 commit 8b40632

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

src/build_function.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ function _build_function(target::JuliaTarget, op, args...;
113113
expr = if cse
114114
fun = Func(dargs, [], Code.cse(unwrap(op)))
115115
(wrap_code !== nothing) && (fun = wrap_code(fun))
116-
toexpr(fun, states)
116+
conv(fun, states)
117117
else
118118
fun = Func(dargs, [], op)
119119
(wrap_code !== nothing) && (fun = wrap_code(fun))
120-
toexpr(fun, states)
120+
conv(fun, states)
121121
end
122122

123123
if expression == Val{true}
@@ -142,14 +142,14 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
142142
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
143143

144144
expr = if cse
145-
toexpr(Func(dargs, [], Code.cse(unwrap(op))), states)
145+
conv(Func(dargs, [], Code.cse(unwrap(op))), states)
146146
else
147-
toexpr(Func(dargs, [], op), states)
147+
conv(Func(dargs, [], op), states)
148148
end
149149

150150
outsym = Symbol("ˍ₋out")
151151
body = inplace_expr(unwrap(op), outsym)
152-
oop_expr = toexpr(Func([outsym, dargs...], [], body), states)
152+
oop_expr = conv(Func([outsym, dargs...], [], body), states)
153153

154154
N = length(shape(op))
155155
op = unwrap(op)
@@ -161,7 +161,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
161161
$outsym
162162
end) |> LiteralExpr
163163
end
164-
ip_expr = toexpr(Func(dargs, [], op_body), states)
164+
ip_expr = conv(Func(dargs, [], op_body), states)
165165
if expression == Val{true}
166166
oop_expr, ip_expr
167167
else
@@ -194,11 +194,28 @@ function fill_array_with_zero!(x::AbstractArray)
194194
end
195195

196196
"""
197+
_build_function(target::JuliaTarget, rhss::AbstractArray, args...;
198+
conv=toexpr,
199+
expression = Val{true},
200+
expression_module = @__MODULE__(),
201+
checkbounds = false,
202+
postprocess_fbody=ex -> ex,
203+
linenumbers = false,
204+
outputidxs=nothing,
205+
skipzeros = false,
206+
force_SA = false,
207+
wrap_code = (nothing, nothing),
208+
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
209+
states = LazyState(),
210+
iip_config = (true, true),
211+
parallel=nothing, cse = false, kwargs...)
212+
197213
Build function target: `JuliaTarget`
198214
199215
```julia
200216
function _build_function(target::JuliaTarget, rhss, args...;
201-
conv = toexpr, expression = Val{true},
217+
conv = toexpr,
218+
expression = Val{true},
202219
checkbounds = false,
203220
linenumbers = false,
204221
headerfun = addheader, outputidxs=nothing,
@@ -250,6 +267,7 @@ Special Keyword Arguments:
250267
safety with `skipzeros`.
251268
"""
252269
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
270+
conv=toexpr,
253271
expression = Val{true},
254272
expression_module = @__MODULE__(),
255273
checkbounds = false,
@@ -303,10 +321,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
303321
end
304322

305323
if expression == Val{true}
306-
return toexpr(oop_expr, states), toexpr(ip_expr, states)
324+
return conv(oop_expr, states), conv(ip_expr, states)
307325
else
308-
return _build_and_inject_function(expression_module, toexpr(oop_expr, states)),
309-
_build_and_inject_function(expression_module, toexpr(ip_expr, states))
326+
return _build_and_inject_function(expression_module, conv(oop_expr, states)),
327+
_build_and_inject_function(expression_module, conv(ip_expr, states))
310328
end
311329
end
312330

0 commit comments

Comments
 (0)