Skip to content

Commit e0a3ec4

Browse files
author
William Moses
committed
more work
1 parent 7c7c8ed commit e0a3ec4

File tree

6 files changed

+105
-115
lines changed

6 files changed

+105
-115
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.2.9"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
9-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
109
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
1110
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1211
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

ext/ReactantCUDAExt.jl

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ function compiler_cache(ctx::MLIR.IR.Context)
437437
return cache
438438
end
439439

440-
function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
440+
Reactant.@overlay function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
441+
@show "recufunction", f, tt
441442
res = Base.@lock CUDA.cufunction_lock begin
442443
# compile the function
443444
cache = compiler_cache(MLIR.IR.context())
@@ -450,43 +451,4 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
450451
res
451452
end
452453

453-
const CC = Core.Compiler
454-
455-
import Core.Compiler:
456-
AbstractInterpreter,
457-
abstract_call,
458-
abstract_call_known,
459-
ArgInfo,
460-
StmtInfo,
461-
AbsIntState,
462-
get_max_methods,
463-
CallMeta,
464-
Effects,
465-
NoCallInfo,
466-
widenconst,
467-
mapany,
468-
MethodResultPure
469-
470-
471-
function Reactant.set_reactant_abi(
472-
interp,
473-
f::typeof(CUDA.cufunction),
474-
arginfo::ArgInfo,
475-
si::StmtInfo,
476-
sv::AbsIntState,
477-
max_methods::Int=get_max_methods(interp, f, sv),
478-
)
479-
(; fargs, argtypes) = arginfo
480-
481-
arginfo2 = ArgInfo(
482-
if fargs isa Nothing
483-
nothing
484-
else
485-
[:($(recufunction)), fargs[2:end]...]
486-
end,
487-
[Core.Const(recufunction), argtypes[2:end]...],
488-
)
489-
return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods)
490-
end
491-
492454
end # module ReactantCUDAExt

src/Interpreter.jl

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ import Core.Compiler:
2121
mapany,
2222
MethodResultPure
2323

24+
25+
Base.Experimental.@MethodTable REACTANT_METHOD_TABLE
26+
27+
macro overlay(method_expr)
28+
def = splitdef(method_expr)
29+
def[:name] = Expr(:overlay, :(Reactant.REACTANT_METHOD_TABLE), def[:name])
30+
return esc(combinedef(def))
31+
end
32+
2433
function set_reactant_abi(
2534
interp,
2635
@nospecialize(f),
@@ -54,50 +63,6 @@ function set_reactant_abi(
5463
end
5564
end
5665

57-
if length(argtypes) >= 5 &&
58-
f === Core.kwcall &&
59-
(
60-
widenconst(argtypes[3]) == typeof(Enzyme.gradient) ||
61-
widenconst(argtypes[3]) == typeof(Enzyme.jacobian)
62-
) &&
63-
widenconst(argtypes[4]) <: Enzyme.Mode
64-
newmode = Enzyme.set_abi(widenconst(argtypes[4]), ReactantABI)
65-
if newmode != widenconst(argtypes[4])
66-
newmodev = newmode()
67-
arginfo2 = ArgInfo(
68-
if fargs isa Nothing
69-
nothing
70-
else
71-
[fargs[1:3]..., :($(newmodev)), fargs[5:end]...]
72-
end,
73-
[argtypes[1:3]..., Core.Const(newmodev), argtypes[5:end]...],
74-
)
75-
return abstract_call_known(interp, f, arginfo2, si, sv, max_methods)
76-
end
77-
end
78-
79-
if length(argtypes) >= 5 &&
80-
methods(f)[1].module == Enzyme &&
81-
widenconst(argtypes[5]) <: Enzyme.Mode &&
82-
(
83-
widenconst(argtypes[4]) == typeof(Enzyme.gradient) ||
84-
widenconst(argtypes[4]) == typeof(Enzyme.jacobian)
85-
)
86-
newmode = Enzyme.set_abi(widenconst(argtypes[5]), ReactantABI)
87-
if newmode != widenconst(argtypes[5])
88-
newmodev = newmode()
89-
arginfo2 = ArgInfo(
90-
if fargs isa Nothing
91-
nothing
92-
else
93-
[fargs[1:4]..., :($(newmodev)), fargs[6:end]...]
94-
end,
95-
[argtypes[1:4]..., Core.Const(newmodev), argtypes[6:end]...],
96-
)
97-
return abstract_call_known(interp, f, arginfo2, si, sv, max_methods)
98-
end
99-
end
100-
10166
return Base.@invoke abstract_call_known(
10267
interp::AbstractInterpreter,
10368
f::Any,
@@ -116,7 +81,7 @@ function set_reactant_abi end
11681
function ReactantInterpreter(; world::UInt=Base.get_world_counter())
11782
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
11883
ReactantCacheToken(),
119-
nothing, #=mt=#
84+
REACTANT_METHOD_TABLE,
12085
world,
12186
true, #=forward_rules=#
12287
true, #=reverse_rules=#
@@ -132,7 +97,7 @@ else
13297
)
13398
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
13499
REACTANT_CACHE,
135-
nothing, #=mt=#
100+
REACTANT_METHOD_TABLE,
136101
world,
137102
true, #=forward_rules=#
138103
true, #=forward_rules=#

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,7 @@ function set_default_backend(backend::String)
128128
return set_default_backend(XLA.backends[backend])
129129
end
130130

131+
# include("../ext/ReactantCUDAExt.jl")
132+
131133
end # module
134+

src/utils.jl

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,88 @@ function apply(f, args...; kwargs...)
3737
return f(args...; kwargs...)
3838
end
3939

40+
function call_with_reactant end
41+
42+
function rewrite_inst(inst)
43+
@show inst
44+
if Meta.isexpr(inst, :call)
45+
rep = Expr(:call, call_with_reactant, inst.args...)
46+
@show rep
47+
return rep
48+
end
49+
return inst
50+
end
51+
52+
function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nospecialize(F::Type), @nospecialize(N::Int), self, @nospecialize(f::Type), @nospecialize(args))
53+
@nospecialize
54+
@show f, args
55+
56+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :f, :args), Core.svec())
57+
58+
# look up the method match
59+
method_error = :(throw(MethodError(f, args, $world)))
60+
61+
interp = ReactantInterpreter(; world)
62+
63+
mt = interp.method_table
64+
65+
sig = Tuple{F, args...}
66+
min_world = Ref{UInt}(typemin(UInt))
67+
max_world = Ref{UInt}(typemax(UInt))
68+
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
69+
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
70+
sig, mt, world, min_world, max_world)
71+
match === nothing && return stub(world, source, method_error)
72+
73+
# look up the method and code instance
74+
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
75+
(Any, Any, Any), match.method, match.spec_types, match.sparams)
76+
77+
result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
78+
frame = Core.Compiler.InferenceState(result, #=cache_mode=#:global, interp)
79+
@assert frame !== nothing
80+
Core.Compiler.typeinf(interp, frame)
81+
@assert Core.Compiler.is_inferred(frame)
82+
83+
#if Core.Compiler.result_is_constabi(interp, frame.result)
84+
# rt = frame.result.result::Core.Compiler.Const
85+
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
86+
#else
87+
opt = Core.Compiler.OptimizationState(frame, interp)
88+
caller = frame.result
89+
@static if VERSION < v"1.11-"
90+
ir = Core.Compiler.run_passes(opt.src, opt, caller)
91+
else
92+
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
93+
Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller)
94+
end
95+
@show ir
96+
for (i, inst) in enumerate(ir.stmts)
97+
@static if VERSION < v"1.11"
98+
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst)
99+
else
100+
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt)
101+
end
102+
end
103+
@show ir
104+
Core.Compiler.finish(interp, opt, ir, caller)
105+
src = Core.Compiler.ir_to_codeinf!(opt)
106+
#end
107+
108+
new_ci = copy(src)
109+
new_ci.slotnames = Symbol[Symbol("#self#"), :f, :args]
110+
new_ci.edges = Core.MethodInstance[mi]
111+
new_ci.min_world = min_world[]
112+
new_ci.max_world = max_world[]
113+
114+
return new_ci
115+
end
116+
117+
@eval function call_with_reactant(f::F, args::Vararg{Any, N}) where {F, N}
118+
$(Expr(:meta, :generated_only))
119+
$(Expr(:meta, :generated, call_with_reactant_generator))
120+
end
121+
40122
function make_mlir_fn(
41123
f,
42124
args,
@@ -131,36 +213,13 @@ function make_mlir_fn(
131213
interp = ReactantInterpreter()
132214

133215
# TODO replace with `Base.invoke_within` if julia#52964 lands
134-
# TODO fix it for kwargs
135-
ircoderes = Base.code_ircode(f, map(typeof, traced_args); interp)
136-
137-
if length(ircoderes) != 1
138-
throw(
139-
AssertionError(
140-
"Could not find unique ircode for $f $traced_args, found $ircoderes"
141-
),
142-
)
143-
end
144-
ir, ty = ircoderes[1]
145-
oc = Core.OpaqueClosure(ir)
216+
# TODO fix it for kwargs
217+
oc = call_with_reactant # Core.OpaqueClosure(ir)
146218

147219
if f === Reactant.apply
148-
oc(traced_args[1], (traced_args[2:end]...,))
220+
oc(f, traced_args[1], (traced_args[2:end]...,))
149221
else
150-
if (length(traced_args) + 1 != length(ir.argtypes)) || (
151-
length(traced_args) > 0 &&
152-
length(ir.argtypes) > 0 &&
153-
!(last(ir.argtypes) isa Core.Const) &&
154-
last(ir.argtypes) != typeof(traced_args[end])
155-
)
156-
@assert ir.argtypes[end] <: Tuple
157-
oc(
158-
traced_args[1:(length(ir.argtypes) - 2)]...,
159-
(traced_args[(length(ir.argtypes) - 1):end]...,),
160-
)
161-
else
162-
oc(traced_args...)
163-
end
222+
oc(f, traced_args...)
164223
end
165224
end
166225

test/cuda.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ end
1111

1212
# basic squaring on GPU
1313
function square!(x)
14-
@cuda blocks = 1 threads = length(x) square_kernel!(x)
14+
# @cuda blocks = 1 threads = length(x) square_kernel!(x)
15+
cr = @cuda launch=false square_kernel!(x)
16+
@show cr
1517
return nothing
1618
end
1719

0 commit comments

Comments
 (0)