@@ -31,12 +31,12 @@ $TYPEDSIGNATURES
Run the policy on the environment and return the total reward and a dataset of observations.
By default, the environment is reset before running the policy.
"""
34- function run_policy! (policy, env:: AbstractEnvironment )
34+ function run_policy! (policy, env:: AbstractEnvironment ; kwargs ... )
3535 total_reward = 0.0
3636 reset! (env; reset_seed= false )
3737 local labeled_dataset
3838 while ! is_terminated (env)
39- y = policy (env)
39+ y = policy (env; kwargs ... )
4040 features, state = observe (env)
4141 if @isdefined labeled_dataset
4242 push! (labeled_dataset, DataSample (; x= features, y_true= y, instance= state))
@@ -49,33 +49,35 @@ function run_policy!(policy, env::AbstractEnvironment)
4949 return total_reward, labeled_dataset
5050end
5151
"""
$TYPEDSIGNATURES

Run the policy on each environment in `envs`, forwarding `kwargs` to the policy call.

Return a `Vector{Float64}` of per-environment total rewards and the concatenation
of all per-environment labeled datasets, in environment order.
"""
function run_policy!(policy, envs::Vector{<:AbstractEnvironment}; kwargs...)
    E = length(envs)
    rewards = zeros(Float64, E)
    datasets = map(1:E) do e
        # Delegate to the single-environment method; record the reward as a
        # side effect so `map` can return just the dataset.
        reward, dataset = run_policy!(policy, envs[e]; kwargs...)
        rewards[e] = reward
        return dataset
    end
    # `reduce(vcat, ...)` avoids splatting a runtime-length vector into varargs;
    # keep the empty-vector fast path so `envs == []` does not throw.
    return rewards, isempty(datasets) ? datasets : reduce(vcat, datasets)
end
6262
"""
$TYPEDSIGNATURES

Run the policy on `env` for `episodes` independent episodes, re-seeding the
environment with `seed` once before the first episode (subsequent episodes are
reset without re-seeding by the single-episode method). Additional `kwargs`
are forwarded to the policy call.

Return the average total reward per episode and the concatenation of all
per-episode labeled datasets.

# Throws
- `ArgumentError` if `episodes` is not positive (the original silently returned
  `NaN` from a division by zero).
"""
function run_policy!(
    policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs...
)
    episodes > 0 || throw(ArgumentError("episodes must be positive, got $episodes"))
    # Re-seed once so the whole sequence of episodes is reproducible.
    reset!(env; reset_seed=true, seed)
    # Collect (reward, dataset) pairs rather than accumulating into a variable
    # reassigned inside the closure, which would be boxed (type-unstable capture).
    results = map(1:episodes) do _
        return run_policy!(policy, env; kwargs...)
    end
    mean_reward = sum(first, results) / episodes
    # `reduce(vcat, ...)` avoids splatting a runtime-length vector into varargs;
    # `episodes > 0` guarantees `results` is non-empty.
    return mean_reward, reduce(vcat, map(last, results))
end
7375
74- function run_policy! (policy, envs:: Vector{<:AbstractEnvironment} , episodes:: Int )
76+ function run_policy! (policy, envs:: Vector{<:AbstractEnvironment} , episodes:: Int ; kwargs ... )
7577 E = length (envs)
7678 rewards = zeros (Float64, E)
7779 datasets = map (1 : E) do e
78- reward, dataset = run_policy! (policy, envs[e], episodes)
80+ reward, dataset = run_policy! (policy, envs[e], episodes; kwargs ... )
7981 rewards[e] = reward
8082 return dataset
8183 end
0 commit comments