Skip to content

Commit 01cfc33

Browse files
author
William Moses
committed
overload working
1 parent b8a925e commit 01cfc33

File tree

2 files changed

+39
-110
lines changed

2 files changed

+39
-110
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ end
206206

207207
const _kernel_instances = Dict{Any, Any}()
208208

209+
struct LLVMFunc{F,tt}
210+
f::Union{F, Nothing}
211+
mod::String
212+
image
213+
entry::String
214+
end
215+
209216

210217
# compile to executable machine code
211218
function compile(job)
@@ -218,8 +225,8 @@ function compile(job)
218225
@show mod
219226
@show modstr
220227
# check if we'll need the device runtime
221-
undefined_fs = filter(collect(functions(meta.ir))) do f
222-
isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
228+
undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
229+
CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
223230
end
224231
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
225232
"__nvvm_reflect" #= TODO: should have been optimized away =#]
@@ -246,7 +253,7 @@ function compile(job)
246253

247254
# validate use of parameter memory
248255
argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
249-
!isghosttype(dt) && !Core.Compiler.isconstType(dt)
256+
!CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt)
250257
end
251258
param_usage = sum(sizeof, argtypes)
252259
param_limit = 4096
@@ -268,7 +275,7 @@ function compile(job)
268275
end
269276

270277
for (i, typ) in enumerate(source_types)
271-
if isghosttype(typ) || Core.Compiler.isconstType(typ)
278+
if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ)
272279
continue
273280
end
274281
name = source_argnames[i]
@@ -306,7 +313,7 @@ function compile(job)
306313
"--output-file", ptxas_output,
307314
ptx_input
308315
])
309-
proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`)
316+
proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`)
310317
log = strip(log)
311318
if !success(proc)
312319
reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
@@ -342,7 +349,7 @@ function compile(job)
342349
"--output-file", nvlink_output,
343350
ptxas_output
344351
])
345-
proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`)
352+
proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`)
346353
log = strip(log)
347354
if !success(proc)
348355
reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
@@ -369,11 +376,7 @@ function compile(job)
369376
modstr, image, meta.entry
370377
end
371378

372-
println(string(modstr))
373-
@show job
374-
@show job.source
375-
@show job.config
376-
LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(entry))
379+
LLVMFunc{job.source.specTypes[1],job.source.specTypes}(nothing, modstr, image, LLVM.name(entry))
377380
end
378381

379382
# link into an executable kernel
@@ -382,13 +385,6 @@ function link(job, compiled)
382385
return compiled
383386
end
384387

385-
struct LLVMFunc{F,tt}
386-
f::F
387-
mod::String
388-
image
389-
entry::String
390-
end
391-
392388
function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1,
393389
shmem::Integer=0) where{F, tt}
394390
blockdim = CUDA.CuDim3(blocks)

src/utils.jl

Lines changed: 25 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,26 @@ macro LineInfoNode(method)
4545
end
4646

4747

48-
function rewrite_inst(inst)
49-
@show inst
48+
49+
function maybe_argextype(
50+
@nospecialize(x),
51+
src,
52+
)
53+
return try
54+
Core.Compiler.argextype(x, src)
55+
catch err
56+
!(err isa Core.Compiler.InvalidIRError) && rethrow()
57+
nothing
58+
end
59+
end
60+
61+
function rewrite_inst(inst, ir)
5062
if Meta.isexpr(inst, :call)
51-
rep = Expr(:call, call_with_reactant, inst.args...)
52-
@show rep
53-
return rep
63+
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir))
64+
if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin)
65+
rep = Expr(:call, call_with_reactant, inst.args...)
66+
return rep
67+
end
5468
end
5569
if Meta.isexpr(inst, :invoke)
5670
return Expr(:call, inst.args[2:end]...)
@@ -204,12 +218,11 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
204218
@nospecialize
205219

206220
args = redub_arguments
207-
@show args
208221

209222
stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, :redub_arguments), Core.svec())
210223

211224
# look up the method match
212-
builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $args")))
225+
builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $redub_arguments")))
213226

214227
if args[1] <: Core.Builtin
215228
return stub(world, source, builtin_error)
@@ -218,7 +231,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
218231
method_error = :(throw(MethodError(args[1], args[2:end], $world)))
219232

220233
interp = ReactantInterpreter(; world)
221-
234+
222235
sig = Tuple{args...}
223236
lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)).matches
224237

@@ -239,7 +252,6 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
239252
(Any, Any, Any), match.method, match.spec_types, match.sparams)
240253

241254
result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
242-
@static if true
243255
frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp)
244256
@assert frame !== nothing
245257
Core.Compiler.typeinf(interp, frame)
@@ -260,19 +272,16 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
260272
end
261273

262274
for (i, inst) in enumerate(ir.stmts)
275+
263276
@static if VERSION < v"1.11"
264-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst)
277+
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst], ir), :inst)
265278
else
266-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt)
279+
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt], ir), :stmt)
267280
end
281+
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
268282
end
269283
Core.Compiler.finish(interp, opt, ir, caller)
270-
271284
src = Core.Compiler.ir_to_codeinf!(opt)
272-
#end
273-
else
274-
src = Core.Compiler.retrieve_code_info(mi, world)
275-
end
276285

277286
# prepare a new code info
278287
code_info = copy(src)
@@ -347,17 +356,9 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
347356

348357
# substitute static parameters, offset slot numbers by number of added slots, and
349358
# offset statement indices by the number of additional statements
350-
@show code_info.code
351359

352-
@show n_prepended_slots
353-
@static if false
354-
Base.Meta.partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...],
355-
n_prepended_slots, length(overdubbed_code), :propagate)
356-
else
357360
arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...],
358361
n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate)
359-
end
360-
@show code_info.code
361362

362363
#callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...)
363364
#push!(overdubbed_code, callexpr)
@@ -371,23 +372,6 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
371372
append!(overdubbed_code, code_info.code)
372373
append!(overdubbed_codelocs, code_info.codelocs)
373374

374-
@show overdubbed_code
375-
376-
@static if false
377-
for i in eachindex(overdubbed_code)
378-
prev = overdubbed_code[i]
379-
if Base.Meta.isexpr(prev, :call)
380-
@show prev
381-
@show prev.args[1]
382-
@show prev.args[1] isa Core.IntrinsicFunction
383-
if !(prev.args[1] isa Core.IntrinsicFunction)
384-
overdubbed_code[i] = Expr(:call, GlobalRef(Reactant, :call_with_reactant), prev.args...)
385-
@show "post", overdubbed_code[i]
386-
end
387-
end
388-
end
389-
end
390-
391375
#=== set `code_info`/`reflection` fields accordingly ===#
392376

393377
if code_info.method_for_inference_limit_heuristics === nothing
@@ -399,58 +383,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
399383
code_info.ssavaluetypes = length(overdubbed_code)
400384
code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
401385

402-
@show code_info
403386
return code_info
404-
405-
self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp))
406-
407-
408-
@show self
409-
self_meths = Base._methods_by_ftype(Tuple{self, Vararg{Any}}, -1, world)
410-
@show self_meths
411-
self_method = (self_meths[1]::Core.MethodMatch).method
412-
self_mi = Core.Compiler.specialize_method(self_method, Tuple{typeof(Reactant.call_with_reactant), sig.parameters...}, Core.svec())
413-
@show self_mi
414-
self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp))
415-
frame = Core.Compiler.InferenceState(self_result, code_info, #=cache_mode=#:global, interp)
416-
@assert frame !== nothing
417-
Core.Compiler.typeinf(interp, frame)
418-
@assert Core.Compiler.is_inferred(frame)
419-
420-
#if Core.Compiler.result_is_constabi(interp, frame.result)
421-
# rt = frame.result.result::Core.Compiler.Const
422-
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
423-
#else
424-
opt = Core.Compiler.OptimizationState(frame, interp)
425-
426-
ir = opt.src
427-
@show ir
428-
for (i, stmt) in enumerate(ir.stmts)
429-
@show stmt
430-
431-
end
432-
433-
@show ir
434-
435-
caller = frame.result
436-
@static if VERSION < v"1.11-"
437-
ir = Core.Compiler.run_passes(ir, opt, caller)
438-
else
439-
ir = Core.Compiler.run_passes_ipo_safe(ir, opt, caller)
440-
Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller)
441-
442-
end
443-
Core.Compiler.finish(interp, opt, ir, caller)
444-
445-
src = Core.Compiler.ir_to_codeinf!(opt)
446-
#end
447-
448-
src = copy(src)
449-
src.ssavaluetypes = length(src.code)
450-
451-
@show src
452-
453-
return src
454387
end
455388

456389
@eval function call_with_reactant(redub_arguments...)

0 commit comments

Comments
 (0)