Skip to content

Commit 4746542

Browse files
authored
Merge pull request #169 from JuliaDiff/sf/ssa_insteadof_stmt
Pass `ssa` instead of `stmt` to `visit_custom!()`
2 parents 53a3ded + 4076b60 commit 4746542

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/codegen/forward_demand.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
166166
inst = ir[ssa]
167167
stmt = inst[:inst]
168168
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
169-
if visit_custom!(ir, stmt, order, recurse)
169+
if visit_custom!(ir, ssa, order, recurse)
170170
ssa_orders[ssa.id] = order => true
171171
return
172172
elseif isa(stmt, PiNode)
@@ -211,7 +211,7 @@ Internal method which generates the code for forward mode diffentiation
211211
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
212212
paired with the order (first deriviative, second derivative etc)
213213
214-
- `visit_custom!(ir::IRCode, stmt, order::Int, recurse::Bool) -> Bool`:
214+
- `visit_custom!(ir::IRCode, ssa, order::Int, recurse::Bool) -> Bool`:
215215
decides if the custom `transform!` should be applied to a `stmt` or not
216216
Default: `false` for all statements
217217
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.

src/stage2/forward.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
1919
end
2020
end
2121

22-
function visit_custom!(ir::IRCode, @nospecialize(stmt), order, recurse)
22+
function visit_custom!(ir::IRCode, ssa::Union{SSAValue,Argument}, order, recurse)
23+
if isa(ssa, Argument)
24+
return true
25+
end
26+
27+
stmt = ir[ssa][:inst]
2328
if isa(stmt, ReturnNode)
2429
recurse(stmt.val)
2530
return true
26-
elseif isa(stmt, Argument)
27-
return true
2831
else
2932
return false
3033
end

0 commit comments

Comments
 (0)