Skip to content

Commit 02927c3

Browse files
committed
Commit forgotten local changes
1 parent 9d5cb64 commit 02927c3

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/stage1/recurse.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,10 @@ function transform!(ci, meth, nargs, sparams, N)
456456
return insert_node_rev!(Expr(:call, getfield, Argument(1),
457457
i - first(orig_bb_ranges[end]) + 1))
458458
elseif isa(current_env, BBEnv)
459+
bbidx = i - current_env.bb_start_idx + 1
460+
@assert bbidx > 0
459461
return insert_node_rev!(Expr(:call, getfield, current_env.ctx_obj,
460-
i - current_env.bb_start_idx + 1))
462+
bbidx))
461463
end
462464
error()
463465
end
@@ -489,9 +491,9 @@ function transform!(ci, meth, nargs, sparams, N)
489491
insert_node_rev!(Expr(:(=), accumulator, accumed))
490492
end
491493
end
492-
if !isa(stmt, Union{GotoNode, GotoIfNot})
493-
#current_env = BBEnv(access_ctx_map(bb+1),
494-
# first(ir.cfg.blocks[bb].stmts))
494+
if !isa(stmt, Union{GotoNode, GotoIfNot, ReturnNode})
495+
current_env = BBEnv(access_ctx_map(bb+1),
496+
first(ir.cfg.blocks[bb].stmts))
495497
end
496498
end
497499

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ function myprod(xs)
6060
return s
6161
end
6262

63+
function mypow(x, n)
64+
r = one(x)
65+
while n > 0
66+
n -= 1
67+
r *= x
68+
end
69+
return r
70+
end
71+
6372
function times_three_while(x)
6473
z = x
6574
i = 3
@@ -95,10 +104,14 @@ let var"'" = Diffractor.PrimeDerivativeBack
95104
@test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0)
96105
@test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0)
97106

107+
# Control flow cases
98108
@test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0)
99109
@test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0)
100110
@test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
101111
@test times_three_while'(1.0) == 3.0
112+
113+
pow5p(x) = (x->mypow(x, 5))'(x)
114+
@test pow5p(1.0) == 5.0
102115
end
103116

104117
# Simple Forward Mode tests

0 commit comments

Comments
 (0)