Closed
Labels: difficult (this is expected to be difficult), enhancement (new feature or request)
Description
Current Status
The AD part of the package has not undergone any major overhaul since I first implemented it around the start of the project. Back then I relied on Flux/Zygote for three reasons: the package was tailored to Flux models at the time, I was entirely new to both AD and Julia, and Zygote let me differentiate through structs, like so:
```julia
"""
    ∂ℓ(
        generator::AbstractGradientBasedGenerator,
        ce::AbstractCounterfactualExplanation,
    )

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators.
It assumes that `Zygote.jl` has gradient access.
"""
function ∂ℓ(
    generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation
)
    return Flux.gradient(ce -> ℓ(generator, ce), ce)[1][:counterfactual_state]
end
```

Pain Points
The current implementation is less than ideal for various reasons:
- Zygote cannot handle nested AD, which is necessary for some counterfactual generators (see "Sort out PROBE", #376).
- Gradients are still taken implicitly, which is not in line with where the broader ecosystem is headed, I believe.
- The previous point also makes it difficult to implement forward-over-reverse to solve the nested AD issue.
- The AD implementation has never been optimized for performance, so there is probably a lot of room for improvement here.
To Do
- Double-check https://gdalle.github.io/JuliaCon2024-AutoDiff/#/title-slide
- Try out DifferentiationInterface.jl
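
As a starting point for the second item, here is a minimal sketch of what explicit, backend-agnostic gradients and a forward-over-reverse second-order call could look like with DifferentiationInterface.jl. The function `toy_loss` and the input `x` are hypothetical stand-ins, not part of this package; wiring the same calls into `∂ℓ` would still need to be worked out.

```julia
using DifferentiationInterface  # re-exports backend types such as AutoZygote and AutoForwardDiff
import ForwardDiff, Zygote      # the backend packages must be loaded as well

# Hypothetical stand-in for the counterfactual loss: squared distance to a vector of ones.
toy_loss(x) = sum(abs2, x .- 1.0)

x = [0.0, 2.0, 3.0]

# Explicit gradients: the backend is an ordinary argument and can be swapped freely.
g_zygote = gradient(toy_loss, AutoZygote(), x)
g_forward = gradient(toy_loss, AutoForwardDiff(), x)

# Forward-over-reverse: ForwardDiff (outer) differentiates the Zygote (inner) gradient.
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
H = hessian(toy_loss, backend, x)
```

Because the backend is passed explicitly, the nested-AD case reduces to choosing a `SecondOrder` combination rather than asking Zygote to differentiate its own output. Whether this specific combination works for the losses used by the generators in #376 is an open question that would need testing.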