Skip to content

Commit d2bfd1f

Browse files
authored
Add JuliaRL_DQN_CartPole (#650)
* add back common networks * add TwinNetwork * sync * add experiment JuliaRL_DQN_CartPole
1 parent c67a604 commit d2bfd1f

File tree

13 files changed

+510
-188
lines changed

13 files changed

+510
-188
lines changed

Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ ReinforcementLearningZoo = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
1212

1313
[compat]
1414
Reexport = "0.2, 1"
15-
ReinforcementLearningBase = "0.10"
16-
ReinforcementLearningCore = "0.9"
17-
ReinforcementLearningEnvironments = "0.7"
18-
ReinforcementLearningZoo = "0.6"
1915
julia = "1.6"
2016

2117
[extras]

src/ReinforcementLearningCore/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1818
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
1919
ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
2020
ReinforcementLearningTrajectories = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
21+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2122
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2223
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2324
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

src/ReinforcementLearningCore/src/policies/learners.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ import Functors
55

66
abstract type AbstractLearner end
77

8-
(L::AbstractLearner)(env) = env |> state |> send_to_device(L) |> L |> send_to_device(env)
8+
(L::AbstractLearner)(env::AbstractEnv) = env |> state |> send_to_device(L) |> L |> send_to_device(env)
99

1010
Base.@kwdef mutable struct Approximator{M,O}
1111
model::M
1212
optimiser::O
1313
end
1414

15-
Functors.functor(x::Approximator) = (model = x.model,), y -> Approximator(y.model, x.state)
15+
Functors.functor(x::Approximator) = (model=x.model,), y -> Approximator(y.model, x.state)
1616

1717
(A::Approximator)(x) = A.model(x)
1818

0 commit comments

Comments
 (0)