11
22module forward_diff_no_inf
3- using Core. Compiler : SSAValue
3+ using Core: SSAValue
44 const CC = Core. Compiler
55
66 using Diffractor, Test
@@ -20,7 +20,14 @@ module forward_diff_no_inf
2020 mi. specTypes = Tuple{map (CC. widenconst, ir. argtypes)... }
2121 mi. def = @__MODULE__
2222
23- for i in 1 : length (ir. stmts) # For testuing purposes we are going to refine everything
23+ for i in 1 : length (ir. stmts)
24+ inst = ir[SSAValue (i)][:inst ]
25+ if Meta. isexpr (inst, :code_coverage_effect )
26+ # delete these as CC._ir_abstract_constant_propagation doesn't work on them
27+ ir[SSAValue (i)][:inst ] = nothing
28+ ir[SSAValue (i)][:type ] = Nothing
29+ end
30+ # For testing purposes we are going to refine everything else
2431 ir[SSAValue (i)][:flag ] |= CC. IR_FLAG_REFINED
2532 end
2633
@@ -39,14 +46,28 @@ module forward_diff_no_inf
3946 typ = stmt[:type ]
4047 ! isa (typ, Type) && continue # If not a Type then something even more informed like a Const
4148 if isabstracttype (typ) || typ <: Union || typ <: UnionAll
42- # @error "Not fully inferred" inst typ
49+ # @error "Not fully inferred" inst typ
4350 return false
4451 end
4552 end
4653 end
4754 return true
4855 end
4956
57+ function findfirst_ssa (predicate, ir)
58+ for ii in 1 : length (ir. stmts)
59+ try
60+ inst = ir[SSAValue (ii)][:inst ]
61+ if predicate (inst)
62+ return SSAValue (ii)
63+ end
64+ catch
65+ # ignore errors so predicate can be simple
66+ end
67+ end
68+ return nothing
69+ end
70+
5071 # ############################## Actual tests:
5172
5273 @testset " Constructors in forward_diff_no_inf!" begin
@@ -108,21 +129,22 @@ module forward_diff_no_inf
108129 end
109130
110131 # only test this on new enough julia versions as exactly what infers can be fussy, as is running inference manually
111- VERSION >= v " 1.12.0-DEV.283" && @testset " Eras mode: $eras_mode " for eras_mode in (false , true )
132+ VERSION >= v " 1.12.0-DEV.283" && @testset " Eras mode: $eras_mode " for eras_mode in (false , true )
112133 foo (x, y) = x* x + y* y
113134 ir = first (only (Base. code_ircode (foo, Tuple{Any, Any})))
114- Diffractor. forward_diff_no_inf! (ir, [SSAValue (1 )] .=> 1 ; transform! = identity_transform!, eras_mode)
135+ mul1_ssa = findfirst_ssa (x-> x. args[1 ]. name== :* , ir)
136+ Diffractor. forward_diff_no_inf! (ir, [mul1_ssa] .=> 1 ; transform! = identity_transform!, eras_mode)
115137 ir = CC. compact! (ir)
116138 ir. argtypes[2 : end ] .= Float64
117139 ir = CC. compact! (ir)
118140 infer_ir! (ir)
119141 CC. verify_ir (ir)
120142 @test isfully_inferred (ir) # passes with and without eras mode
121-
122- Diffractor. forward_diff_no_inf! (ir, [SSAValue (3 )] .=> 1 ; transform! = identity_transform!, eras_mode)
143+
144+ add_ssa = findfirst_ssa (x-> x. args[1 ]. name== :+ , ir)
145+ Diffractor. forward_diff_no_inf! (ir, [add_ssa] .=> 1 ; transform! = identity_transform!, eras_mode)
123146 ir = CC. compact! (ir)
124147 infer_ir! (ir)
125-
126148 CC. verify_ir (ir)
127149 if eras_mode
128150 @test isfully_inferred (ir)
@@ -131,6 +153,5 @@ module forward_diff_no_inf
131153 @assert ! isfully_inferred (ir)
132154 end
133155 end
134-
135156end # module
136157
0 commit comments