File tree Expand file tree Collapse file tree 3 files changed +17
-0
lines changed Expand file tree Collapse file tree 3 files changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -19,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1919
2020[weakdeps ]
2121FastBroadcast = " 7034ab61-46d4-4ed7-9d0f-46aef9175898"
22+ ForwardDiff = " f6369f11-7733-5829-9624-2563aa707210"
2223Measurements = " eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2324MonteCarloMeasurements = " 0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2425Tracker = " 9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -27,6 +28,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2728
2829[extensions ]
2930RecursiveArrayToolsFastBroadcastExt = " FastBroadcast"
31+ RecursiveArrayToolsForwardDiffExt = " ForwardDiff"
3032RecursiveArrayToolsMeasurementsExt = " Measurements"
3133RecursiveArrayToolsMonteCarloMeasurementsExt = " MonteCarloMeasurements"
3234RecursiveArrayToolsReverseDiffExt = [" ReverseDiff" , " Zygote" ]
Original file line number Diff line number Diff line change 1+ module RecursiveArrayToolsForwardDiffExt
2+
3+ using RecursiveArrayTools
4+ using ForwardDiff
5+
6+ function ForwardDiff. extract_derivative (:: Type{T} , y:: AbstractVectorOfArray ) where {T}
7+ ForwardDiff. extract_derivative .(T, y)
8+ end
9+
10+ end
Original file line number Diff line number Diff line change @@ -62,6 +62,10 @@ function loss8(x)
6262 return sum (abs2, res)
6363end
6464
65+ function loss9 (x)
66+ return VectorOfArray ([collect (3 i: 3 i+ 3 ) .* x for i in 1 : 5 ])
67+ end
68+
6569x = float .(6 : 10 )
6670loss (x)
6771@test Zygote. gradient (loss, x)[1 ] == ForwardDiff. gradient (loss, x)
@@ -72,3 +76,4 @@ loss(x)
7276@test Zygote. gradient (loss6, x)[1 ] == ForwardDiff. gradient (loss6, x)
7377@test Zygote. gradient (loss7, x)[1 ] == ForwardDiff. gradient (loss7, x)
7478@test Zygote. gradient (loss8, x)[1 ] == ForwardDiff. gradient (loss8, x)
79+ @test ForwardDiff. derivative (loss9, 0.0 ) == VectorOfArray ([collect (3 i: 3 i+ 3 ) for i in 1 : 5 ])
You can’t perform that action at this time.
0 commit comments