diff --git a/Project.toml b/Project.toml index a11d100a..eafdfe58 100644 --- a/Project.toml +++ b/Project.toml @@ -22,11 +22,11 @@ ChainRules = "1.44.6" ChainRulesCore = "1.20" Combinatorics = "1" Compiler = "~0" -Cthulhu = "2.10.1" +Cthulhu = "2.16.3" OffsetArrays = "1" PrecompileTools = "1" StaticArrays = "1" -StructArrays = "0.6" +StructArrays = "0.6, 0.7" julia = "1.10" [extras] diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 2afe827c..c9936b06 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -34,7 +34,7 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize # discover what they are. frules should be written in such a way that # whether or not they return `nothing`, only depends on the non-tangent arguments frule_arginfo = ArgInfo(nothing, frule_argtypes) - frule_si = StmtInfo(true) + frule_si = StmtInfo(true, false) # turn off frule analysis in the frule to avoid cycling interp′ = disable_forward(interp) frule_call = CC.abstract_call_gf_by_type(interp′, diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 9681c591..d8eea936 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -352,11 +352,11 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met end end - method_info = CC.MethodInfo(src) + info = @static VERSION ≥ v"1.12.0-DEV.1293" ? CC.SpecInfo(src) : CC.MethodInfo(src) argtypes = ir.argtypes[1:mi.def.nargs] world = get_inference_world(interp) - irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world) - rt = CC._ir_abstract_constant_propagation(interp, irsv) + irsv = IRInterpretationState(interp, info, ir, mi, argtypes, world, src.min_world, src.max_world) + rt = CC.ir_abstract_constant_propagation(interp, irsv) ir = compact!(ir) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 276405c0..a7eb5263 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -14,12 +14,13 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source end - return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...) else oc_nargs = Int64(meth_nargs) - Expr(:new_opaque_closure, typ, Union{}, Any, - Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...) + ocm = Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci) end + oc = Expr(:new_opaque_closure, typ, Union{}, Any, true, ocm, revs...) + @static VERSION < v"1.12.0-DEV.691" ? deleteat!(oc.args, 4) : nothing + oc end function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::Int, interp=nothing, curs=nothing) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 5dab4dd1..d21a0acd 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, val, Δ->(NoTangent(), NoTangent(), Δ) end +# XXX: We should instead skip differentiation in the IR. +function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol) + val = getproperty(mod, name) + val, Δ->(NoTangent(), NoTangent(), NoTangent()) +end + Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581 # Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 93fb30c1..21329679 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -1,5 +1,5 @@ # Utilities that should probably go into CC -using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter +using .CC: IRCode, CFG, BasicBlock, BBIdxIter function Base.push!(cfg::CFG, bb::BasicBlock) @assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start @@ -30,10 +30,6 @@ if VERSION < v"1.12.0-DEV.1268" Base.copy(ir::IRCode) = CC.copy(ir) - CC.BasicBlock(x::UnitRange) = - BasicBlock(StmtRange(first(x), last(x))) - CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) = - BasicBlock(StmtRange(first(x), last(x)), preds, succs) Base.length(c::CC.NewNodeStream) = CC.length(c) Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...) Base.size(x::CC.UnitRange) = CC.size(x) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index a1b861fa..4ba91d29 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -6,7 +6,8 @@ struct ∂⃖recurse{N}; end include("recurse.jl") -function generate_lambda_ex(world::UInt, source::LineNumberNode, +# source is a Method starting from https://github.com/JuliaLang/julia/pull/57230 +function generate_lambda_ex(world::UInt, source::Union{Method,LineNumberNode}, args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr) stub = Core.GeneratedFunctionStub(identity, args, sparams) return stub(world, source, body) @@ -16,7 +17,7 @@ struct NonTransformableError args end -function perform_optic_transform(world::UInt, source::LineNumberNode, +function perform_optic_transform(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N} @assert N >= 1 diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index 1e09e881..1f562bb6 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -183,10 +183,10 @@ function split_critical_edges!(ir) bb = ir.stmts[i][:inst].args[1] ir.stmts[i][:inst] = nothing bbnew = bb + ninserted - insert!(cfg.blocks, bbnew, BasicBlock(i:i)) + insert!(cfg.blocks, bbnew, BasicBlock(StmtRange(i:i))) bb_rename_offset[bb] += 1 bblock = cfg.blocks[bbnew+1] - cfg.blocks[bbnew+1] = BasicBlock((i+1):last(bblock.stmts), + cfg.blocks[bbnew+1] = BasicBlock(StmtRange((i+1):last(bblock.stmts)), bblock.preds, bblock.succs) i += 1 while i <= last(bblock.stmts) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 954a98eb..e4f99348 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -222,7 +222,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E) return ci end -function perform_fwd_transform(world::UInt, source::LineNumberNode, +function perform_fwd_transform(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E} if all(x->x <: ZeroBundle, args) return generate_lambda_ex(world, source, diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index dd0bfdb1..01cb4638 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -21,12 +21,11 @@ end # unlikely to be the actual interface. For now, it is used for testing. function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false) interp = ADInterpreter(; forward=true, backward=false) - match = Base._which(tt) - frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) - mi = frame.linfo + mi = @ccall jl_method_lookup_by_tt(tt::Any, Base.tls_world_age()::Csize_t, #= method table =# nothing::Any)::Ref{MethodInstance} + ci = CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI) src = CC.copy(interp.unopt[0][mi].src) - ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode) + ir = CC.copy((@atomic :monotonic ci.inferred).ir::IRCode) # Find all Return Nodes vals = Pair{SSAValue, Int}[] @@ -83,6 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = fa end ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode) + ir.argtypes[1] = Tuple{} return OpaqueClosure(ir) end diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 1351a4ad..2d11bbcc 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -273,12 +273,76 @@ end # TODO: `get_remarks` should get a cursor? Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing) -function CC.finish(sv::InferenceState, interp::ADInterpreter) - res = @invoke CC.finish(sv::InferenceState, interp::AbstractInterpreter) - key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(sv) : CC.any(sv.result.overridden_by_const)) ? sv.result : sv.linfo - interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(sv) +@static if VERSION ≥ v"1.13.0-DEV.126" +function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter, cycleid::Int) + res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int) + key = CC.is_constproped(state) ? state.result : state.linfo + interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state) + return res +end +else +function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter) + res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter) + key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(state) : CC.any(state.result.overridden_by_const)) ? state.result : state.linfo + interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state) return res end +end + +@static if VERSION ≥ v"1.12.0-DEV.1823" +@static if VERSION ≥ v"1.13.0-DEV.126" || VERSION ≥ v"1.12.0-alpha1" +CC.finishinfer!(state::InferenceState, interp::ADInterpreter, cycleid::Int) = diffractor_finish(CC.finishinfer!, state, interp, cycleid) +else +CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp) +end +@static if VERSION ≥ v"1.12.0-DEV.1988" +function CC.finish!(interp::ADInterpreter, caller::InferenceState, validation_world::UInt) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt) +end +else +function CC.finish!(interp::ADInterpreter, caller::InferenceState) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState) +end +end + +elseif VERSION ≥ v"1.12.0-DEV.734" +CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp) +function CC.finish!(interp::ADInterpreter, caller::InferenceState; + can_discard_trees::Bool=false) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState; + can_discard_trees) +end + +elseif VERSION ≥ v"1.11.0-DEV.737" +CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp) +function CC.finish!(interp::ADInterpreter, caller::InferenceState) + result = caller.result + opt = result.src + Cthulhu.set_cthulhu_source!(result) + if opt isa CC.OptimizationState + CC.ir_to_codeinf!(opt) + end + return nothing +end +function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange, + result::InferenceResult) + return result.src +end + +else # VERSION < v"1.11.0-DEV.737" +CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp) +function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange, + result::InferenceResult) + return create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.finish!(::ADInterpreter, caller::InferenceResult) + Cthulhu.set_cthulhu_source(interp, caller) +end + +end # @static if const StmtFlag = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8 function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo), @@ -303,10 +367,6 @@ function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.C end @static if VERSION ≥ v"1.12.0-DEV.45" -function CC.transform_result_for_cache(interp::ADInterpreter, - ::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end function CC.src_inlining_policy(interp::ADInterpreter, @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag) ret = diffractor_inlining_policy(src, info, stmt_flag) @@ -316,10 +376,6 @@ function CC.src_inlining_policy(interp::ADInterpreter, src::Any, info::CC.CallInfo, stmt_flag::StmtFlag) end else -function CC.transform_result_for_cache(interp::ADInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end function CC.inlining_policy(interp::ADInterpreter, @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any}) @@ -351,17 +407,6 @@ function CC.optimize(interp::ADInterpreter, opt::OptimizationState, end =# -function _finish!(caller::InferenceResult) - effects = caller.ipo_effects - caller.src = Cthulhu.create_cthulhu_source(caller.src, effects) -end - -@static if VERSION ≥ v"1.11.0-DEV.737" -CC.finish!(::ADInterpreter, caller::InferenceState) = _finish!(caller.result) -else -CC.finish!(::ADInterpreter, caller::InferenceResult) = _finish!(caller) -end - @static if VERSION ≥ v"1.11.0-DEV.1278" function CC.bail_out_const_call(interp::ADInterpreter, result::CC.MethodCallResult, si::StmtInfo, sv::CC.AbsIntState) diff --git a/test/forward_diff_no_inf.jl b/test/forward_diff_no_inf.jl index a3f62b82..f1c60f18 100644 --- a/test/forward_diff_no_inf.jl +++ b/test/forward_diff_no_inf.jl @@ -31,11 +31,15 @@ module forward_diff_no_inf ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED end - method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) + info = @static if VERSION ≥ v"1.12.0-DEV.1293" + CC.SpecInfo(#=nargs=#length(ir.argtypes), #=isva=#false, #=propagate_inbounds=#true, nothing) + else + CC.MethodInfo(#=propagate_inbounds=#true, nothing) + end min_world = world = (interp).world max_world = Diffractor.get_world_counter() - irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world) - (rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv) + irsv = CC.IRInterpretationState(interp, info, ir, mi, ir.argtypes, world, min_world, max_world) + (rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv) return rt end @@ -79,6 +83,7 @@ module forward_diff_no_inf ir = first(only(Base.code_ircode(foo_148, Tuple{Float64}))) Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!) ir2 = CC.compact!(ir) + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right end @@ -96,6 +101,7 @@ module forward_diff_no_inf stmt = ir2.stmts[stmt_idx] @test stmt[:inst].name == :_coeff @test stmt[:type] == Float64 + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 28.0 end @@ -124,6 +130,7 @@ module forward_diff_no_inf Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!) ir2 = CC.compact!(ir) CC.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158) + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly end @@ -154,4 +161,3 @@ module forward_diff_no_inf end end end # module - diff --git a/test/gradcheck.jl b/test/gradcheck.jl index bfada096..ca8313f3 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -95,7 +95,8 @@ end @testset "sum, prod" begin @test gradcheck(x -> sum(abs2, x), randn(4, 3, 2)) - @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) + # Fails in `diffract_ir!` on $(Expr(:isdefined, :($(Expr(:static_parameter, 1))))) + @test_broken gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) @test gradcheck(X -> sum(x -> x^2, X), randn(10)) diff --git a/test/reverse.jl b/test/reverse.jl index 023be88c..2264917d 100644 --- a/test/reverse.jl +++ b/test/reverse.jl @@ -70,9 +70,9 @@ let var"'" = Diffractor.PrimeDerivativeBack # Integration tests @test @inferred(sin'(1.0)) == cos(1.0) @test @inferred(sin''(1.0)) == -sin(1.0) - @test @inferred(sin'''(1.0)) == -cos(1.0) # FIXME: These error with: # Control flow support not fully implemented yet for higher-order reverse mode (TODO) + @test_broken @inferred(sin'''(1.0)) == -cos(1.0) @test_broken @inferred(sin''''(1.0)) == sin(1.0) @test_broken @inferred(sin'''''(1.0)) == cos(1.0) @test_broken @inferred(sin''''''(1.0)) == -sin(1.0)