6 changes: 6 additions & 0 deletions DifferentiationInterface/CHANGELOG.md
@@ -7,8 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.7...main)

### Fixed

- Handle constant derivatives with runtime activity for Enzyme

## [0.7.7](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...DifferentiationInterface-v0.7.7)

### Fixed

- Improve support for empty inputs (still not guaranteed) ([#835](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/835))

## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6)
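For context on the `Unreleased` entry above, here is a minimal sketch (not part of the PR) of the failure mode being fixed: with runtime activity enabled, Enzyme may hand back the primal output itself as the shadow of a result that does not depend on the active input, instead of a zero tangent. Names below are illustrative.

```julia
using Enzyme

# `g` ignores its active input and returns stored data, so its true
# derivative with respect to `x` is zero everywhere.
const data = [1.0, 2.0]
g(x) = data

mode = Enzyme.set_runtime_activity(Enzyme.Forward)
dy, = Enzyme.autodiff(mode, g, Duplicated([3.0, 4.0], [1.0, 0.0]))
# Without the safeguard introduced in this PR, `dy` may alias `data`
# (the primal output) rather than being a zero tangent.
```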
@@ -1,7 +1,7 @@
module DifferentiationInterfaceEnzymeExt

using ADTypes: ADTypes, AutoEnzyme
using Base: Fix1
using Base: Fix1, datatype_pointerfree
import DifferentiationInterface as DI
using EnzymeCore:
Active,
@@ -42,7 +42,8 @@ using Enzyme:
jacobian,
make_zero,
make_zero!,
onehot
onehot,
runtime_activity

DI.check_available(::AutoEnzyme) = true

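A note on the new import: `Base.datatype_pointerfree` is an internal, undocumented Base helper, used below to decide whether aliasing between primal and shadow is even possible for a given type. A rough sketch of its behavior, as this extension relies on it:

```julia
using Base: datatype_pointerfree

# Pointer-free types (plain bits) cannot alias through a shared pointer,
# so the runtime-activity safeguard can skip them entirely.
datatype_pointerfree(Float64)          # true
datatype_pointerfree(Vector{Float64})  # false: arrays carry a data pointer
```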
@@ -37,6 +37,7 @@ function DI.value_and_pushforward(
x_and_dx = Duplicated(x, dx)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
dy = runtime_activity_safeguard(backend, y, dy)
return y, (dy,)
end

@@ -54,8 +55,10 @@ function DI.value_and_pushforward(
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
return y, values(ty)
ty_nt, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
ty = values(ty_nt)
ty = runtime_activity_safeguard(backend, y, ty)
return y, ty
end
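Why the `ty_nt` rename above: for `BatchDuplicated` inputs, Enzyme returns the batch of tangents as a NamedTuple, while DI's pushforward API promises a plain tuple; `values` performs the conversion. A sketch with made-up numbers and illustrative field names:

```julia
ty_nt = (var"1" = [1.0], var"2" = [2.0])  # shape of Enzyme's batched tangents
ty = values(ty_nt)                        # ([1.0], [2.0]): the NTuple DI returns
```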

function DI.pushforward(
@@ -66,6 +69,9 @@
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
if has_runtime_activity(backend)
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
end
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep
mode = forward_noprimal(backend)
@@ -85,14 +91,18 @@
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
if has_runtime_activity(backend)
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
end
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep
mode = forward_noprimal(backend)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
return values(ty)
ty_nt = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
ty = values(ty_nt)
return ty
end

function DI.value_and_pushforward!(
@@ -168,7 +178,9 @@ function DI.gradient(
derivs = gradient(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
return first(derivs)
deriv = first(derivs)
deriv = runtime_activity_safeguard(backend, x, deriv)
return deriv
end

function DI.value_and_gradient(
@@ -186,7 +198,9 @@ function DI.value_and_gradient(
(; derivs, val) = gradient(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
return val, first(derivs)
deriv = first(derivs)
deriv = runtime_activity_safeguard(backend, x, deriv)
return val, deriv
end

function DI.gradient!(
@@ -7,6 +7,7 @@ function seeded_autodiff_thunk(
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
shadow_result = runtime_activity_safeguard(rmode, result, shadow_result)
if RA <: Active
dinputs = only(reverse(f, args..., dresult, tape))
else
@@ -30,6 +31,7 @@ function batch_seeded_autodiff_thunk(
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
shadow_results = runtime_activity_safeguard(rmode_rightwidth, result, shadow_results)
if RA <: Active
dinputs = only(reverse(f, args..., dresults, tape))
else
@@ -193,3 +193,39 @@ end

batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}

has_runtime_activity(mode::Mode) = runtime_activity(mode)
has_runtime_activity(::AutoEnzyme{Nothing}) = false
has_runtime_activity(backend::AutoEnzyme{<:Mode}) = has_runtime_activity(backend.mode)

function runtime_activity_safeguard(
backend_or_mode::Union{<:AutoEnzyme,<:Mode}, primal::T, shadow::T
) where {T}
# TODO: improve datatype_pointerfree to take Ptr into account
if has_runtime_activity(backend_or_mode) &&
!datatype_pointerfree(T) &&
pointer(primal) === pointer(shadow) # TODO: doesn't work beyond arrays
return make_zero(shadow)
else
return shadow
end
end

function runtime_activity_safeguard(
backend_or_mode::Union{<:AutoEnzyme,<:Mode},
primal::T,
shadow::Union{NTuple{N,T},NamedTuple},
) where {T,N}
# TODO: improve datatype_pointerfree to take Ptr into account
if has_runtime_activity(backend_or_mode) &&
!datatype_pointerfree(T) &&
pointer(primal) === pointer(shadow[1]) # TODO: doesn't work beyond arrays
return make_zero(shadow)
else
return shadow
end
end

function runtime_activity_safeguard(::Union{<:AutoEnzyme,<:Mode}, primal, shadow::Nothing)
return nothing
end
16 changes: 16 additions & 0 deletions DifferentiationInterface/test/Back/Enzyme/test.jl
@@ -196,3 +196,19 @@ end
excluded=[:jacobian],
)
end;

@testset "Runtime activity" begin
# TODO: higher-level operators not tested
test_differentiation(
AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward)),
DIT.unknown_activity(default_scenarios());
excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pullback),
logging=LOGGING,
)
test_differentiation(
AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
DIT.unknown_activity(default_scenarios());
excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pushforward),
logging=LOGGING,
)
end
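End-to-end, the behavior these testsets check looks roughly like the following (a hedged sketch; `f` is a stand-in mirroring `DIT.unknown_activity`, not an API of the package):

```julia
import DifferentiationInterface as DI
using ADTypes: AutoEnzyme
using Enzyme: Enzyme

# Activity is only known at runtime: the flag decides whether the output
# is a precomputed constant or an actual function of `x`.
f(x, yc, return_constant) = return_constant ? copy(yc) : abs2.(x)

backend = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))
x = [1.0, 2.0]
yc = abs2.(x)

dy, = DI.pushforward(
    f, backend, x, ([1.0, 0.0],), DI.Constant(yc), DI.Constant(true)
)
# With the safeguard, `dy` is a zero tangent, because the returned value
# is the constant `yc` rather than a function of `x`.
```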
49 changes: 49 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/modify.jl
@@ -163,6 +163,54 @@ function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
)
end

struct UnknownActivityReturn{pl_fun,F}
f::F
end

function Base.show(io::IO, f::UnknownActivityReturn)
return print(io, "UnknownActivityReturn($(f.f))")
end

function (f::UnknownActivityReturn{:out})(x, yc, return_constant::Bool)
if return_constant
return copy(yc)
else
return f.f(x)
end
end

function (f::UnknownActivityReturn{:in})(y, x, yc, return_constant::Bool)
if return_constant
copyto!(y, copy(yc))
else
f.f(y, x)
end
return nothing
end

"""
unknown_activity(scen::Scenario)

Return a new scenario identical to `scen`, except that the function takes two additional constant arguments: the theoretical output, and a boolean flag indicating whether that precomputed output should be returned as-is (making the output constant) or recomputed from the input.
"""
function unknown_activity(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f) = deepcopy(scen)
zero_scen = deepcopy(zero(scen))
@assert isempty(scen.contexts)
unknown_f = UnknownActivityReturn{pl_fun,typeof(f)}(f)
return Scenario{op,pl_op,pl_fun}(;
f=unknown_f,
x=scen.x,
y=scen.y,
t=scen.t,
contexts=(Constant(scen.y), Constant(true)),
res1=zero_scen.res1,
res2=zero_scen.res2,
prep_args=(; scen.prep_args..., contexts=(Constant(scen.y), Constant(true))),
name=isnothing(scen.name) ? nothing : scen.name * " [unknown activity]",
)
end
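A quick sketch of what the modifier produces, using `sin` as a stand-in scenario function (illustrative, not part of the test suite):

```julia
wrapped = UnknownActivityReturn{:out,typeof(sin)}(sin)

wrapped(1.0, sin(1.0), false)  # recomputes sin(1.0): the output is active
wrapped(1.0, sin(1.0), true)   # returns the precomputed output: constant

# The scenario built by `unknown_activity` passes the last two arguments as
# `Constant(scen.y)` and `Constant(true)`, with expected results taken from
# `zero(scen)`, so a correct backend must report zero derivatives even
# though the function looks differentiable.
```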

struct MultiplyByConstant{pl_fun,F} <: FunctionModifier
f::F
end
@@ -366,6 +414,7 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)
constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens)
unknown_activity(scens::AbstractVector{<:Scenario}) = unknown_activity.(scens)

## Compute results with backend

7 changes: 7 additions & 0 deletions DifferentiationInterfaceTest/test/weird.jl
@@ -65,6 +65,13 @@ test_differentiation(
logging=LOGGING,
);

test_differentiation(
AutoFiniteDiff(),
unknown_activity(default_scenarios);
excluded=SECOND_ORDER,
logging=LOGGING,
);

## Neural nets

test_differentiation(