Skip to content

Commit c7afab7

Browse files
author
William Moses
committed
continuing
1 parent 01cfc33 commit c7afab7

File tree

3 files changed

+55
-25
lines changed

3 files changed

+55
-25
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,6 @@ function compile(job)
222222
asm, meta = CUDA.GPUCompiler.compile(:asm, job)
223223
mod = meta.ir
224224
modstr = string(mod)
225-
@show mod
226-
@show modstr
227225
# check if we'll need the device runtime
228226
undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
229227
CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -375,8 +373,7 @@ function compile(job)
375373

376374
modstr, image, meta.entry
377375
end
378-
379-
LLVMFunc{job.source.specTypes[1],job.source.specTypes}(nothing, modstr, image, LLVM.name(entry))
376+
LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry))
380377
end
381378

382379
# link into an executable kernel
@@ -385,20 +382,23 @@ function link(job, compiled)
385382
return compiled
386383
end
387384

388-
function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1,
389-
shmem::Integer=0) where{F, tt}
385+
function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
386+
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
387+
@show args
388+
@show call_kwargs
389+
390390
blockdim = CUDA.CuDim3(blocks)
391391
threaddim = CUDA.CuDim3(threads)
392392

393-
@show args
394-
395393
mlir_args = MLIR.IR.Value[]
396394
restys = MLIR.IR.Type[]
397395
aliases = MLIR.API.MlirAttribute[]
396+
rarrays = TracedRArray[]
398397
for (i, a) in enumerate(args)
399398
@show a
400-
@assert a isa CuDeviceArray
401-
ta = Base.pointer_to_objref(a.ptr)::TracedRArray
399+
@assert a isa CuTracedArray
400+
ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
401+
push!(rarrays, ta)
402402
arg = ta.mlir_data
403403
arg = Reactant.Compiler.transpose_val(arg)
404404
push!(restys, MLIR.IR.Type(arg))
@@ -415,7 +415,10 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD
415415
end
416416

417417
output_operand_aliases=MLIR.ArrayAttr.get(MLIR.IR.context(), aliases)
418-
MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
418+
call = MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
419+
for (i, res) in enumerate(rarrays)
420+
ta.mlir_data = Reactant.Compiler.transpose_val(MLIR.IR.result(call, i-1))
421+
end
419422
#CUDA.cuLaunchKernel(f,
420423
# blockdim.x, blockdim.y, blockdim.z,
421424
# threaddim.x, threaddim.y, threaddim.z,

src/utils.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ function rewrite_inst(inst, ir)
6363
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir))
6464
if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin)
6565
rep = Expr(:call, call_with_reactant, inst.args...)
66-
return rep
66+
return true, rep
6767
end
6868
end
6969
if Meta.isexpr(inst, :invoke)
70-
return Expr(:call, inst.args[2:end]...)
70+
return false, Expr(:call, inst.args[2:end]...)
7171
end
72-
return inst
72+
return false, inst
7373
end
7474

7575
const REDUB_ARGUMENTS_NAME = gensym("redub_arguments")
@@ -120,10 +120,14 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any}
120120
return x
121121
end
122122
if isa(x, Core.ReturnNode)
123-
return Core.ReturnNode(
123+
if !isdefined(x, :val)
124+
return Core.ReturnNode(:nothing)
125+
else
126+
return Core.ReturnNode(
124127
_arg_partially_inline!(x.val, slot_replacements, type_signature, static_param_values,
125128
slot_offset, arg_offset, statement_offset, boundscheck),
126-
)
129+
)
130+
end
127131
end
128132
if isa(x, Core.GotoIfNot)
129133
return Core.GotoIfNot(
@@ -257,12 +261,19 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
257261
Core.Compiler.typeinf(interp, frame)
258262
@assert Core.Compiler.is_inferred(frame)
259263

264+
method = match.method
265+
@show mi
266+
@show method
267+
260268
#if Core.Compiler.result_is_constabi(interp, frame.result)
261269
# rt = frame.result.result::Core.Compiler.Const
262270
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
263271
#else
264272
opt = Core.Compiler.OptimizationState(frame, interp)
265273

274+
@show Core.Compiler.retrieve_code_info(mi, world)
275+
@show opt.src
276+
266277
caller = frame.result
267278
@static if VERSION < v"1.11-"
268279
ir = Core.Compiler.run_passes(opt.src, opt, caller)
@@ -271,21 +282,35 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
271282
Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller)
272283
end
273284

274-
for (i, inst) in enumerate(ir.stmts)
285+
@show ir
286+
any_changed = false
287+
for (i, inst) in enumerate(ir.stmts)
275288

289+
276290
@static if VERSION < v"1.11"
277-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst], ir), :inst)
291+
changed, next = rewrite_inst(inst[:inst], ir)
292+
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
278293
else
279-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt], ir), :stmt)
294+
changed, next = rewrite_inst(inst[:stmt], ir)
295+
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
280296
end
281-
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
297+
if changed
298+
any_changed = true
299+
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
300+
end
282301
end
283302
Core.Compiler.finish(interp, opt, ir, caller)
303+
@show "post", ir
284304
src = Core.Compiler.ir_to_codeinf!(opt)
305+
306+
@show any_changed, src
307+
if !any_changed
308+
src = Core.Compiler.retrieve_code_info(mi, world)
309+
@show "post non change", src
310+
end
285311

286312
# prepare a new code info
287313
code_info = copy(src)
288-
method = match.method
289314
static_params = match.sparams
290315
signature = sig
291316
is_invoke = args[1] === typeof(Core.invoke)
@@ -352,6 +377,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
352377
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
353378
end
354379

380+
@show code_info.code
355381
#=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#
356382

357383
# substitute static parameters, offset slot numbers by number of added slots, and
@@ -383,6 +409,8 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self,
383409
code_info.ssavaluetypes = length(overdubbed_code)
384410
code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
385411

412+
@show code_info
413+
386414
return code_info
387415
end
388416

test/cuda.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@ end
1111

1212
# basic squaring on GPU
1313
function square!(x)
14-
# @cuda blocks = 1 threads = length(x) square_kernel!(x)
15-
cr = @cuda launch=false square_kernel!(x)
16-
@show cr
14+
@cuda blocks = 1 threads = length(x) square_kernel!(x)
1715
return nothing
1816
end
1917

2018
@testset "Square Kernel" begin
2119
oA = collect(1:1:64)
2220
A = Reactant.to_rarray(oA)
21+
@show @code_hlo square!(A)
2322
func = @compile square!(A)
24-
@test all(A .≈ (oA .* oA))
23+
@test all(Array(A) .≈ (oA .* oA))
2524
end

0 commit comments

Comments
 (0)