-
-
Notifications
You must be signed in to change notification settings - Fork 108
Open
Description
The two functions
reward(::TicTacToeEnv,::Player)
is_terminated(::TicTacToeEnv)
result in a small but needless allocation due to a type instability in the call to get_tic_tac_toe_state_info().
To see this, you can use:
# Reproduction script: benchmark `reward` on a freshly constructed TicTacToeEnv
# so the per-call allocation shows up in the BenchmarkTools report.
using ReinforcementLearning
using BenchmarkTools
env = TicTacToeEnv()
# `$env` interpolates the value into the benchmark expression, so the
# measurement is not distorted by accessing `env` as a non-constant global.
display(@benchmark reward($env))
display(@benchmark is_terminated($env))

I was able to fix this problem (and save about 7% of time) with 3 small changes to TicTacToeEnv.jl. There may be other ways to fix this, but these were the simplest changes I could find.
import ReinforcementLearningEnvironments: get_tic_tac_toe_state_info
"""
    get_tic_tac_toe_state_info()

Return the state-info cache `RLEnvs.TIC_TAC_TOE_STATE_INFO`, filling it on first use.

When the cache is empty, a fresh `TicTacToeEnv()` is stored as the root entry and
`walk` is used to visit further environments. Each environment not yet present is
stored as a key mapping to a named tuple with the fields

- `index`: running number of the entry in insertion order (the root is `1`);
- `is_terminated`: `true` when a winner was found or `env.board[:, :, 1]` contains
  no `true` entry;
- `winner`: `Player(:Cross)`, `Player(:Nought)`, or `nothing`.

The cache is type-asserted to a concrete `Dict` type, so the return type of this
function is inferrable by callers.
"""
function ReinforcementLearningEnvironments.get_tic_tac_toe_state_info()
    # Assert the concrete type once and use the typed local everywhere below,
    # so every cache access in this function goes through the same binding.
    cache = RLEnvs.TIC_TAC_TOE_STATE_INFO::Dict{TicTacToeEnv,@NamedTuple{index::Int64, is_terminated::Bool, winner::Union{Nothing,Player}}}
    if isempty(cache)
        @info "initializing tictactoe state info cache..."
        t = @elapsed begin
            # A `Ref` instead of a reassigned local: the counter is updated inside
            # the `do` closure, and a captured variable that is reassigned gets boxed.
            counter = Ref(1)
            root = TicTacToeEnv()
            cache[root] = (index=counter[], is_terminated=false, winner=nothing)
            # NOTE(review): `walk` and `is_win` are used unqualified here; confirm
            # they are in scope where this definition is evaluated.
            walk(root) do env
                if !haskey(cache, env)
                    counter[] += 1
                    # presumably layer 1 of the board flags empty positions — TODO confirm
                    has_empty_pos = any(view(env.board, :, :, 1))
                    w = if is_win(env, Player(:Cross))
                        Player(:Cross)
                    elseif is_win(env, Player(:Nought))
                        Player(:Nought)
                    else
                        nothing
                    end
                    cache[env] = (
                        index=counter[],
                        is_terminated=!(has_empty_pos && isnothing(w)),
                        winner=w,
                    )
                end
            end
        end
        @info "finished initializing tictactoe state info cache in $t seconds"
    end
    return cache
end
import ReinforcementLearning: reward
"""
    reward(env::TicTacToeEnv, player::Player)

Return the reward of `player` in the state `env`, looked up in the state-info cache.

The result is `1` when the state is terminated and `player` is the recorded winner,
`-1` when the state is terminated with a different winner, and `0` otherwise
(state not terminated, or terminated without a winner).
"""
function RLBase.reward(env::TicTacToeEnv, player::Player)
    # Go through get_tic_tac_toe_state_info() only while the cache is still
    # empty; afterwards index the cache directly.
    state = if isempty(RLEnvs.TIC_TAC_TOE_STATE_INFO)
        get_tic_tac_toe_state_info()[env]
    else
        RLEnvs.TIC_TAC_TOE_STATE_INFO[env]
    end

    state.is_terminated || return 0

    victor = state.winner
    isnothing(victor) && return 0
    return victor === player ? 1 : -1
end
import ReinforcementLearning: is_terminated
"""
    is_terminated(env::TicTacToeEnv)

Return the `is_terminated` flag stored for `env` in the state-info cache.
"""
function RLBase.is_terminated(env::TicTacToeEnv)
    # Go through get_tic_tac_toe_state_info() only while the cache is still
    # empty; afterwards index the cache directly.
    if isempty(RLEnvs.TIC_TAC_TOE_STATE_INFO)
        return get_tic_tac_toe_state_info()[env].is_terminated
    end
    return RLEnvs.TIC_TAC_TOE_STATE_INFO[env].is_terminated
end
Metadata
Assignees
Labels
No labels