Skip to content

Commit 0c61f5d

Browse files
author
William Moses
committed
fix
1 parent 8a0bfb0 commit 0c61f5d

File tree

2 files changed

+17
-86
lines changed

2 files changed

+17
-86
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ReactantCUDAExt
22

33
using CUDA
44
using Reactant:
5-
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
5+
Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
66
using ReactantCore: @trace
77

88
using Adapt
@@ -465,7 +465,7 @@ function transpose_val(val)
465465
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
466466
end
467467

468-
function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
468+
Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
469469
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
470470
@show args
471471
@show call_kwargs
@@ -522,7 +522,7 @@ function compiler_cache(ctx::MLIR.IR.Context)
522522
return cache
523523
end
524524

525-
Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
525+
Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
526526
@show "recufunction", f, tt
527527
res = Base.@lock CUDA.cufunction_lock begin
528528
# compile the function
@@ -543,6 +543,7 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
543543
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
544544
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
545545
end
546+
@show res
546547
res
547548
end
548549

src/utils.jl

Lines changed: 13 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -307,78 +307,6 @@ function call_with_reactant_generator(
307307
# No method could be found (including in our method table), bail with an error
308308
if lookup_result == nothing
309309
return stub(world, source, method_error)
310-
tmp_min_world = Ref{UInt}(typemin(UInt))
311-
tmp_max_world = Ref{UInt}(typemax(UInt))
312-
match = ccall(
313-
:jl_gf_invoke_lookup_worlds,
314-
Any,
315-
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
316-
Tuple{typeof(throw_method_error),sig},
317-
nothing,
318-
world,
319-
tmp_min_world,
320-
tmp_max_world,
321-
) #=mt=#
322-
@assert match !== nothing
323-
324-
# look up the method and code instance
325-
mi = ccall(
326-
:jl_specializations_get_linfo,
327-
Ref{Core.MethodInstance},
328-
(Any, Any, Any),
329-
match.method,
330-
match.spec_types,
331-
match.sparams,
332-
)
333-
334-
ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo
335-
336-
src = copy(ci)
337-
src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME]
338-
339-
src.edges = Any[
340-
ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig
341-
]
342-
src.min_world = min_world[]
343-
src.max_world = max_world[]
344-
345-
push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1)))
346-
push!(overdubbed_codelocs, 0)
347-
348-
expr_fn = Core.SSAValue(length(overdubbed_code))
349-
350-
push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2)))))
351-
push!(overdubbed_codelocs, 0)
352-
353-
expr_lastindex = Core.SSAValue(length(overdubbed_code))
354-
355-
push!(overdubbed_code, :(2:($expr_lastindex)))
356-
push!(overdubbed_codelocs, 0)
357-
358-
expr_slice = Core.SSAValue(length(overdubbed_code))
359-
360-
push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice)))
361-
push!(overdubbed_codelocs, 0)
362-
363-
expr_args = Core.SSAValue(length(overdubbed_code))
364-
365-
push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world)))
366-
push!(overdubbed_codelocs, 0)
367-
368-
expr_method = Core.SSAValue(length(overdubbed_code))
369-
370-
push!(overdubbed_code, :($(Base.throw)($expr_method)))
371-
push!(overdubbed_codelocs, 0)
372-
373-
push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code))))
374-
push!(overdubbed_codelocs, 0)
375-
376-
src.code = overdubbed_code
377-
src.codelocs = overdubbed_codelocs
378-
src.ssavaluetypes = length(overdubbed_code)
379-
src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
380-
381-
return src
382310
end
383311

384312
match = lookup_result::Core.MethodMatch
@@ -438,17 +366,19 @@ function call_with_reactant_generator(
438366
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
439367
# screws up type inference after this (TODO this should be fixed).
440368
any_changed = false
441-
for (i, inst) in enumerate(ir.stmts)
442-
@static if VERSION < v"1.11"
443-
changed, next = rewrite_inst(inst[:inst], ir, interp)
444-
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
445-
else
446-
changed, next = rewrite_inst(inst[:stmt], ir, interp)
447-
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
448-
end
449-
if changed
450-
any_changed = true
451-
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
369+
if should_rewrite_ft(args[1]) && !is_reactant_method(mi)
370+
for (i, inst) in enumerate(ir.stmts)
371+
@static if VERSION < v"1.11"
372+
changed, next = rewrite_inst(inst[:inst], ir, interp)
373+
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
374+
else
375+
changed, next = rewrite_inst(inst[:stmt], ir, interp)
376+
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
377+
end
378+
if changed
379+
any_changed = true
380+
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
381+
end
452382
end
453383
end
454384

0 commit comments

Comments
 (0)