Skip to content

Commit c2e80f2

Browse files
authored
follow up the @generated function update (#130)
We need to make sure to generate `:lambda` expression for fallback cases, otherwise we will end up with the following code generation error: <https://github.com/JuliaLang/julia/blob/38d24e574caab20529a61a6f7444c9e473724ccc/src/method.c#L603>
1 parent 87731ed commit c2e80f2

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/stage1/generated.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ struct ∂⃖recurse{N}; end
66

77
include("recurse.jl")
88

9+
function generate_lambda_ex(world::UInt, source::LineNumberNode,
10+
args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr)
11+
stub = Core.GeneratedFunctionStub(identity, args, sparams)
12+
return stub(world, source, body)
13+
end
14+
915
function perform_optic_transform(world::UInt, source::LineNumberNode,
1016
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
1117
@assert N >= 1
@@ -15,8 +21,8 @@ function perform_optic_transform(world::UInt, source::LineNumberNode,
1521
mthds = Base._methods_by_ftype(sig, -1, world)
1622
if mthds === nothing || length(mthds) != 1
1723
# Core.println("[perform_optic_transform] ", sig, " => ", mthds)
18-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
19-
return stub(world, source, :(throw(MethodError(ff, args))))
24+
return generate_lambda_ex(world, source,
25+
Core.svec(:ff, :args), Core.svec(), :(throw(MethodError(ff, args))))
2026
end
2127
match = only(mthds)::Core.MethodMatch
2228

src/stage1/recurse_fwd.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,17 @@ end
3131
function perform_fwd_transform(world::UInt, source::LineNumberNode,
3232
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
3333
if all(x->x <: ZeroBundle, args)
34-
return :(∂☆passthrough(args))
34+
return generate_lambda_ex(world, source,
35+
Core.svec(:ff, :args), Core.svec(), :(∂☆passthrough(args)))
3536
end
3637

3738
# Check if we have an rrule for this function
3839
sig = Tuple{map(π, args)...}
3940
mthds = Base._methods_by_ftype(sig, -1, world)
4041
if mthds === nothing || length(mthds) != 1
4142
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)
42-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
43-
return stub(world, source, :(∂☆nomethd(args)))
43+
return generate_lambda_ex(world, source,
44+
Core.svec(:ff, :args), Core.svec(), :(∂☆nomethd(args)))
4445
end
4546
match = only(mthds)::Core.MethodMatch
4647

0 commit comments

Comments
 (0)