Open
Description
The Matrix Multiplication ODE gradient tests fail with all reverse-mode AD backends (ReverseDiff, Tracker, Zygote) on all Julia versions.
Test Location
test/concrete_solve_derivatives.jl - "Matrix Multiplication ODE" testset
Current Status
Tests are marked as @test_broken in the test suite.
Test Code
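using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Test

@testset "Matrix Multiplication ODE" begin
    solvealg_test = Tsit5()
    sensealg_test = InterpolatingAdjoint()
    tspan = (0.0, 1.0)
    u0_mat = rand(4, 8)  # matrix-valued initial state
    p0 = rand(16)        # entries of the 4x4 coefficient matrix

    # Linear ODE du/dt = A * u with A = reshape(p, 4, 4)
    f_aug(u, p, t) = reshape(p, 4, 4) * u

    function loss_mat(p)
        prob = ODEProblem(f_aug, u0_mat, tspan, p; alg = solvealg_test, sensealg = sensealg_test)
        sol = solve(prob)
        return sum(sol[:, :, end])  # sum of all entries of the final state
    end

    # `grad_fn` (presumably the gradient helper for the AD backend under test) and
    # `loss_mat2` are defined elsewhere in the test file and are not shown here.
    # These tests fail for ReverseDiff, Tracker, and Zygote
    res2 = grad_fn(loss_mat, p0)
    res4 = grad_fn(loss_mat2, p0)
    @test res2 ≈ res4 atol = 1.0e-10
    @test res2 ≈ ForwardDiff.gradient(loss_mat, p0) atol = 1.0e-10
end
Related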
This issue was identified during the AD backend test refactoring in PR #1322.
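Because `f_aug` is linear in `u`, the loss in the test code has a closed form: the solution at `t = 1` is `exp(A) * u0_mat` with `A = reshape(p, 4, 4)`. That yields a reference gradient that depends on neither the ODE solver nor any AD backend. The sketch below uses only the `LinearAlgebra` and `Random` standard libraries; the helpers `closed_form_loss`, `closed_form_grad` and `central_diff_grad` are hypothetical and not part of the test suite. The exact gradient uses the block-triangular identity for the Fréchet derivative of the matrix exponential, and a central-difference gradient cross-checks it. Gradients computed through the ODE solve can only be expected to match this reference up to the solver's tolerances.

```julia
using LinearAlgebra, Random

# Hypothetical helper: closed-form version of loss_mat. For du/dt = A*u the
# state at t = 1 is exp(A) * u0, so the loss is sum(exp(A) * u0).
closed_form_loss(p, u0) = sum(exp(reshape(p, 4, 4)) * u0)

# Hypothetical helper: exact gradient with respect to p.
# sum(exp(A) * u0) == tr(exp(A) * C) with C = u0 * ones(8) * ones(4)',
# whose gradient in A is L(A', C'), where L(X, E) is the Fréchet derivative
# of exp at X in direction E. It is read off the block-triangular identity
#   exp([X E; 0 X]) == [exp(X) L(X, E); 0 exp(X)].
function closed_form_grad(p, u0)
    n = 4
    A = reshape(p, n, n)
    C = u0 * ones(size(u0, 2)) * ones(n)'
    M = [A' C'; zeros(n, n) A']
    G = exp(M)[1:n, n+1:2n]
    return vec(G)  # column-major, same ordering as reshape(p, 4, 4)
end

# Hypothetical helper: central-difference gradient, used only as a cross-check.
function central_diff_grad(f, p; h = 1e-6)
    g = similar(p)
    for k in eachindex(p)
        pp = copy(p); pp[k] += h
        pm = copy(p); pm[k] -= h
        g[k] = (f(pp) - f(pm)) / (2h)
    end
    return g
end

Random.seed!(1)
u0_mat = rand(4, 8)
p0 = rand(16)
g_exact = closed_form_grad(p0, u0_mat)
g_fd = central_diff_grad(p -> closed_form_loss(p, u0_mat), p0)
println(maximum(abs.(g_exact .- g_fd)))
```

The block-exponential route is chosen because it gives the gradient to machine precision without any differentiation machinery, so a mismatch against it points at the backend rather than at the reference.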