Use Enzyme or DifferentiationInterface for autodiff #300

@pat-alt

Current Status

The AD part of the package has not undergone any major overhaul since I first implemented it around the start of the project. Back then I relied on Flux/Zygote for three reasons: the package was tailored to Flux models at the time, I was entirely new to AD and Julia, and I could use Zygote to differentiate through structs, like so:

"""
    ∂ℓ(
        generator::AbstractGradientBasedGenerator,
        ce::AbstractCounterfactualExplanation,
    )

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators.
It assumes that `Zygote.jl` has gradient access.
"""
function ∂ℓ(
    generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation
)
    return Flux.gradient(ce -> ℓ(generator, ce), ce)[1][:counterfactual_state]
end

Pain Points

The current implementation is less than ideal for various reasons:

  • Zygote does not reliably handle nested AD (differentiating through a gradient call), which is necessary for some counterfactual generators (see #376, "Sort out PROBE").
  • Gradients are still taken implicitly, which I believe is out of line with where the broader ecosystem is headed.
  • The previous point also makes it difficult to implement forward-over-reverse differentiation to solve the nested AD issue.
  • The AD implementation has never been optimized for performance, so there is likely a lot of room for improvement here.
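
To make the last three points concrete, here is a rough sketch of explicit gradients with a swappable backend, plus forward-over-reverse for the nested case, using DifferentiationInterface. The function `loss` is a hypothetical toy stand-in and not the package's `ℓ`, and the backend choices are assumptions for illustration only, not a proposal for the final design:

```julia
using DifferentiationInterface
import ForwardDiff, Zygote  # backends are loaded but not brought into scope

# Hypothetical toy loss standing in for the counterfactual loss.
loss(x) = sum(abs2, x) / 2

x = [1.0, 2.0, 3.0]

# Explicit gradients: the backend is an argument, so it can be swapped freely.
g_rev = gradient(loss, AutoZygote(), x)       # reverse mode
g_fwd = gradient(loss, AutoForwardDiff(), x)  # forward mode

# Nested AD via forward-over-reverse: ForwardDiff (outer) differentiates
# through the Zygote (inner) gradient.
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
H = hessian(loss, backend, x)
```

For this toy loss the gradient is `x` itself and the Hessian is the identity, which makes it easy to sanity-check a backend before pointing it at a real generator.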

To Do

Metadata

Labels: difficult (this is expected to be difficult), enhancement (new feature or request)