Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
steps:
- group: ":test_tube: Tests"
steps:
- label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
matrix:
setup:
version:
Expand Down Expand Up @@ -33,7 +33,7 @@ steps:
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 120

- label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}"
matrix:
Expand Down Expand Up @@ -70,7 +70,7 @@ steps:
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 120

- group: ":racehorse: Benchmarks"
steps:
Expand Down
7 changes: 0 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

Expand Down
14 changes: 1 addition & 13 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Core.Compiler:

Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE)

function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def)
function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def)
return Base.Experimental.var"@overlay"(
__source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def
)
Expand Down Expand Up @@ -479,15 +479,3 @@ function overload_autodiff(
end
end
end

@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
24 changes: 24 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with
# Revise.jl. Essentially files that contain reactant_overrides cannot be revised
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_overlay @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")

include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
Expand Down
Loading