@@ -248,9 +248,9 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
248
248
# identify where to insert. Must be after phi blocks
249
249
pos = SSAValue (find_end_of_phi_block (ir, arg. id))
250
250
if order == 0
251
- insert_node! (ir, pos, NewInstruction (Expr (:call , primal, arg), Any ), #= attach_after=# true )
251
+ insert_node! (ir, pos, NewInstruction (Expr (:call , primal, arg)), #= attach_after=# true )
252
252
else
253
- insert_node! (ir, pos, NewInstruction (Expr (:call , truncate, arg, Val {order} ()), Any ), #= attach_after=# true )
253
+ insert_node! (ir, pos, NewInstruction (Expr (:call , truncate, arg, Val {order} ())), #= attach_after=# true )
254
254
end
255
255
end
256
256
end
@@ -262,7 +262,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
262
262
return transform! (ir, arg, order, maparg)
263
263
elseif isa (arg, GlobalRef)
264
264
@assert isconst (arg)
265
- return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg), Any ))
265
+ return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg)))
266
266
elseif isa (arg, QuoteNode)
267
267
return ZeroBundle {order} (arg. value)
268
268
end
@@ -300,6 +300,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
300
300
# TODO : New PiNode that discriminates based on primal?
301
301
inst[:inst ] = maparg (stmt. val, SSAValue (ssa), order)
302
302
inst[:type ] = Any
303
+ inst[:flag ] |= CC. IR_FLAG_REFINED
303
304
elseif isa (stmt, GlobalRef)
304
305
if ! isconst (stmt)
305
306
# Non-const GlobalRefs need to need to be accessed as seperate statements
@@ -310,6 +311,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
310
311
elseif isa (stmt, SSAValue) || isa (stmt, QuoteNode)
311
312
inst[:inst ] = maparg (stmt, SSAValue (ssa), order)
312
313
inst[:type ] = Any
314
+ inst[:flag ] |= CC. IR_FLAG_REFINED
313
315
elseif isa (stmt, Expr) || isa (stmt, PhiNode) || isa (stmt, PhiCNode) ||
314
316
isa (stmt, UpsilonNode) || isa (stmt, GotoIfNot) || isa (stmt, Argument)
315
317
urs = userefs (stmt)
@@ -318,6 +320,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
318
320
end
319
321
inst[:inst ] = urs[]
320
322
inst[:type ] = Any
323
+ inst[:flag ] |= CC. IR_FLAG_REFINED
321
324
else
322
325
val = ZeroBundle {order} (inst[:inst ])
323
326
inst[:inst ] = val
@@ -336,10 +339,12 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
336
339
ir = compact! (ir)
337
340
338
341
for i = 1 : length (ir. stmts)
339
- if ir[SSAValue (i)][:type ] == Any
340
- # TODO : this flag should actually be being set at the insert site
341
- # and we should be filtering on if it is present rather than [:type]=Any
342
- ir[SSAValue (i)][:flag ] |= CC. IR_FLAG_REFINED
342
+ inst = ir[SSAValue (i)][:inst ]
343
+ if ! isa (inst, ReturnNode) && ir[SSAValue (i)][:type ] === Any
344
+ if iszero (ir[SSAValue (i)][:flag ] & CC. IR_FLAG_REFINED)
345
+ @warn " IR_FLAG_REFINED Flag missed on statement" i inst
346
+ ir[SSAValue (i)][:flag ] |= CC. IR_FLAG_REFINED
347
+ end
343
348
end
344
349
end
345
350
0 commit comments