diff --git a/src/driver.jl b/src/driver.jl
index 507e0866..3420aaea 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -406,7 +406,7 @@ end
     if validate
         @timeit_debug to "validation" begin
             check_invocation(job)
-            check_ir(job, ir)
+            check_llvm_ir(job, ir)
         end
     end

diff --git a/src/interface.jl b/src/interface.jl
index 57ebb013..25311c56 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -183,7 +183,7 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false

 # provide a specific interpreter to use.
 get_interpreter(@nospecialize(job::CompilerJob)) =
-    GPUInterpreter(ci_cache(job), method_table(job), job.source.world)
+    GPUInterpreter(job, ci_cache(job), method_table(job))

 # does this target support throwing Julia exceptions with jl_throw?
 # if not, calls to throw will be replaced with calls to the GPU runtime
diff --git a/src/jlgen.jl b/src/jlgen.jl
index 1a2f038f..37d10567 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -170,31 +170,29 @@ using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams,
     InferenceState, OptimizationParams

 struct GPUInterpreter <: AbstractInterpreter
+    job::CompilerJob
+
     global_cache::CodeCache
     method_table::Union{Nothing,Core.MethodTable}

     # Cache of inference results for this particular interpreter
     local_cache::Vector{InferenceResult}
-    # The world age we're working inside of
-    world::UInt

     # Parameters for inference and optimization
     inf_params::InferenceParams
     opt_params::OptimizationParams

-    function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt)
-        @assert world <= Base.get_world_counter()
+    function GPUInterpreter(job::CompilerJob, cache::CodeCache, mt::Union{Nothing,Core.MethodTable})
+        @assert job.source.world <= Base.get_world_counter()
         return new(
+            job,
             cache,
             mt,

             # Initially empty cache
             Vector{InferenceResult}(),

-            # world age counter
-            world,
-
             # parameters for inference and optimization
             InferenceParams(unoptimize_throw_blocks=false),
             VERSION >= v"1.8.0-DEV.486" ?
                 OptimizationParams() :
@@ -205,14 +203,28 @@ end

 Core.Compiler.InferenceParams(interp::GPUInterpreter) = interp.inf_params
 Core.Compiler.OptimizationParams(interp::GPUInterpreter) = interp.opt_params
-Core.Compiler.get_world_counter(interp::GPUInterpreter) = interp.world
+Core.Compiler.get_world_counter(interp::GPUInterpreter) = interp.job.source.world
 Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
-Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(interp.global_cache, interp.world)
+Core.Compiler.code_cache(interp::GPUInterpreter) =
+    WorldView(interp.global_cache, Core.Compiler.get_world_counter(interp))

 # No need to do any locking since we're not putting our results into the runtime cache
 Core.Compiler.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
 Core.Compiler.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing

+import Core.Compiler: retrieve_code_info, validate_code_in_debug_mode, InferenceState
+# Replace usage sites of `retrieve_code_info`. `OptimizationState` is one such use, but in
+# all interesting use-cases it is derived from an `InferenceState`. There is a third use in
+# `typeinf_ext`, for when the module forbids inference.
+function InferenceState(result::InferenceResult, cached::Symbol, interp::GPUInterpreter)
+    src = retrieve_code_info(result.linfo)
+    src === nothing && return nothing
+    validate_code_in_debug_mode(result.linfo, src, "lowered")
+    check_julia_ir(interp, result.linfo, src)
+    return InferenceState(result, src, cached, interp)
+end
+
+
 function Core.Compiler.add_remark!(interp::GPUInterpreter, sv::InferenceState, msg)
     @safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg"
 end
@@ -228,14 +240,14 @@ if isdefined(Base.Experimental, Symbol("@overlay"))
 using Core.Compiler: OverlayMethodTable
 if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
 Core.Compiler.method_table(interp::GPUInterpreter) =
-    OverlayMethodTable(interp.world, interp.method_table)
+    OverlayMethodTable(Core.Compiler.get_world_counter(interp), interp.method_table)
 else
 Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
-    OverlayMethodTable(interp.world, interp.method_table)
+    OverlayMethodTable(Core.Compiler.get_world_counter(interp), interp.method_table)
 end
 else
 Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
-    WorldOverlayMethodTable(interp.world)
+    WorldOverlayMethodTable(Core.Compiler.get_world_counter(interp))
 end


diff --git a/src/validation.jl b/src/validation.jl
index 47743956..338f5022 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -32,7 +32,7 @@ function check_method(@nospecialize(job::CompilerJob))
     if job.source.kernel
         cache = ci_cache(job)
         mt = method_table(job)
-        interp = GPUInterpreter(cache, mt, world)
+        interp = GPUInterpreter(job, cache, mt)
         rt = return_type(only(ms); interp)

         if rt != Nothing
@@ -102,6 +102,55 @@ struct InvalidIRError <: Exception
     errors::Vector{IRError}
 end

+# Julia IR
+
+const UNDEFINED_GLOBAL = "use of an undefined global binding"
+const MUTABLE_GLOBAL = "use of a mutable global binding"
+
+function check_julia_ir(interp, mi, src)
+    # pseudo (single-frame) backtrace pointing to a source code location
+    function backtrace(i)
+        loc = src.linetable[i]
+        [StackTraces.StackFrame(loc.method, loc.file, loc.line, mi, false, false, C_NULL)]
+    end
+
+    function check(i, x, errors::Vector{IRError})
+        if x isa Expr
+            for y in x.args
+                check(i, y, errors)
+            end
+        elseif x isa GlobalRef
+            Base.isbindingresolved(x.mod, x.name) || return
+            # XXX: when does this happen? do we miss any cases by bailing out early?
+            # why doesn't calling `Base.resolve(x, force=true)` work?
+            if !Base.isdefined(x.mod, x.name)
+                push!(errors, (UNDEFINED_GLOBAL, backtrace(i), x))
+            end
+            if !Base.isconst(x.mod, x.name)
+                push!(errors, (MUTABLE_GLOBAL, backtrace(i), x))
+            end
+
+            # TODO: make the validation conditional, but make sure we don't cache invalid IR
+
+            # TODO: perform more validation? e.g. disallow Arrays and other CPU values?
+        end
+
+        return
+    end
+
+    errors = IRError[]
+    for (i, x) in enumerate(src.code)
+        check(i, x, errors)
+    end
+    if !isempty(errors)
+        throw(InvalidIRError(interp.job, errors))
+    end
+
+    return
+end
+
+# LLVM IR
+
 const RUNTIME_FUNCTION = "call to the Julia runtime"
 const UNKNOWN_FUNCTION = "call to an unknown function"
 const POINTER_FUNCTION = "call through a literal pointer"
@@ -117,6 +166,8 @@ function Base.showerror(io::IO, err::InvalidIRError)
                 print(io, " (call to ", meta, ")")
             elseif kind == DELAYED_BINDING
                 print(io, " (use of '", meta, "')")
+            else
+                print(io, " (", meta, ")")
             end
         end
         Base.show_backtrace(io, bt)
@@ -132,8 +183,8 @@ function Base.showerror(io::IO, err::InvalidIRError)
     return
 end

-function check_ir(job, args...)
-    errors = check_ir!(job, IRError[], args...)
+function check_llvm_ir(job, args...)
+    errors = check_llvm_ir!(job, IRError[], args...)
     unique!(errors)
     if !isempty(errors)
         throw(InvalidIRError(job, errors))
@@ -142,18 +193,18 @@
     return
 end

-function check_ir!(job, errors::Vector{IRError}, mod::LLVM.Module)
+function check_llvm_ir!(job, errors::Vector{IRError}, mod::LLVM.Module)
     for f in functions(mod)
-        check_ir!(job, errors, f)
+        check_llvm_ir!(job, errors, f)
     end

     return errors
 end

-function check_ir!(job, errors::Vector{IRError}, f::LLVM.Function)
+function check_llvm_ir!(job, errors::Vector{IRError}, f::LLVM.Function)
     for bb in blocks(f), inst in instructions(bb)
         if isa(inst, LLVM.CallInst)
-            check_ir!(job, errors, inst)
+            check_llvm_ir!(job, errors, inst)
         end
     end

@@ -162,7 +213,7 @@ end

 const libjulia = Ref{Ptr{Cvoid}}(C_NULL)

-function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
+function check_llvm_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
     bt = backtrace(inst)
     dest = called_value(inst)
     if isa(dest, LLVM.Function)
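Below is a minimal standalone Julia sketch (not part of the patch) of the two GlobalRef classifications that the new `check_julia_ir` performs with `Base.isdefined` and `Base.isconst`. It omits the `Base.isbindingresolved` bail-out and the backtrace bookkeeping, and the module `Demo` and its bindings are invented purely for illustration.

# Illustrative only: `Demo`, `gravity`, `counter` and `not_defined_anywhere` are
# hypothetical names, not GPUCompiler code.
module Demo
    const gravity = 9.81   # constant global binding: passes the check
    counter = 0            # mutable global binding: would be reported as MUTABLE_GLOBAL
end

for name in (:gravity, :counter, :not_defined_anywhere)
    ref = GlobalRef(Demo, name)
    if !Base.isdefined(ref.mod, ref.name)
        println(ref, " -> use of an undefined global binding")
    elseif !Base.isconst(ref.mod, ref.name)
        println(ref, " -> use of a mutable global binding")
    else
        println(ref, " -> ok")
    end
end

Unlike this sketch, the real check records both errors when a binding is undefined and non-constant, and it attaches a single-frame backtrace built from the lowered code's line table.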