Skip to content

Commit f172791

Browse files
Merge pull request #125 from ChrisRackauckas-Claude/compat-functors-0.5
Bump Functors compat to 0.5 and fix Zygote 0.7 compatibility
2 parents 8a29779 + 266bc59 commit f172791

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ DiffEqBase = "6.151"
2929
Distributions = "v0.25.107"
3030
DocStringExtensions = "0.9.3"
3131
Flux = "0.14.16, 0.15, 0.16"
32-
Functors = "0.4.11"
32+
Functors = "0.4.11, 0.5"
3333
LinearAlgebra = "1.10"
3434
Random = "1.10"
3535
Reexport = "1.2.2"

src/DeepBSDE.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function DiffEqBase.solve(
165165
)
166166
return map(sol) do _sol
167167
predict_ans = Array(_sol)
168-
(predict_ans[1:(end - 1), end], predict_ans[end, end])
168+
predict_ans[:, end]
169169
end
170170
end
171171

@@ -176,7 +176,8 @@ function DiffEqBase.solve(
176176
end
177177

178178
function loss_n_sde()
179-
return mean(sum(abs2, g(X) - u) for (X, u) in predict_n_sde())
179+
preds = predict_n_sde()
180+
return mean(sum(abs2, g(pred[1:(end - 1)]) - pred[end]) for pred in preds)
180181
end
181182

182183
iters = eltype(x0)[]

src/NNStopping.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct NNStoppingModelArray{M}
4747
ms::M
4848
end
4949

50-
Flux.@functor NNStoppingModelArray
50+
@functor NNStoppingModelArray
5151

5252
function (model::NNStoppingModelArray)(X, G)
5353
XG = cat(X, reshape(G, 1, size(G)...), dims = 1)

0 commit comments

Comments
 (0)