11export BoltzmannPolicy, BoltzmannMixturePolicy
22
33"""
4- BoltzmannPolicy(policy, temperature, [rng::AbstractRNG])
4+ BoltzmannPolicy(policy, temperature, [clip_threshold, rng::AbstractRNG])
55
66Policy that samples actions according to a Boltzmann distribution with the
77specified `temperature`. The unnormalized log probability of taking an action
@@ -16,6 +16,11 @@ Higher temperatures lead to an increasingly random policy, whereas a temperature
1616of zero corresponds to a deterministic policy. Q-values are computed according
1717to the underlying `policy` provided as an argument to the constructor.
1818
19+ A `clip_threshold` can be specified as a negative value, so that
20+ ``[Q(s, a) - Q_max] / T`` is clipped to a minimum of `clip_threshold`. This
21+ prevents zero-probability actions, and ensures that the relative odds of the
22+ least vs. most probable action are no smaller than `exp(clip_threshold)`.
23+
1924Note that wrapping an existing policy in a `BoltzmannPolicy` does not ensure
2025consistency of the state values ``V`` and Q-values ``Q`` according to the
2126Bellman equation, since this would require repeated Bellman updates to ensure
@@ -24,22 +29,25 @@ convergence.
@auto_hash_equals struct BoltzmannPolicy{P, R <: AbstractRNG} <: PolicySolution
    policy::P # Underlying policy that supplies the Q-values
    temperature::Float64 # Boltzmann temperature (zero means deterministic argmax)
    clip_threshold::Union{Float64, Nothing} # Lower clip on (Q - Q_max)/T, or `nothing` to disable
    rng::R # Random number generator used when sampling actions
end
2935
# Rewrapping an existing BoltzmannPolicy replaces its sampling parameters
# instead of nesting wrappers, so the inner Q-value policy is reused directly.
BoltzmannPolicy(policy::BoltzmannPolicy, temperature, clip_threshold, rng) =
    BoltzmannPolicy(policy.policy, temperature, clip_threshold, rng)

# Default construction: no clipping, global RNG.
BoltzmannPolicy(policy, temperature, clip_threshold = nothing) =
    BoltzmannPolicy(policy, temperature, clip_threshold, Random.GLOBAL_RNG)

# Convenience method: RNG supplied without a clip threshold.
BoltzmannPolicy(policy, temperature, rng::AbstractRNG) =
    BoltzmannPolicy(policy, temperature, nothing, rng)
3542
# Pretty-print the policy, showing only the wrapped `policy` field.
# NOTE(review): default indent assumed to be the empty string — the scraped
# source renders it ambiguously; confirm against the repository.
function Base.show(io::IO, ::MIME"text/plain", sol::BoltzmannPolicy)
    indent = get(io, :indent, "")
    show_struct(io, sol; indent = indent, show_fields = (:policy,))
end
4047
# Copy the wrapped policy; temperature, clip threshold, and RNG are shared.
Base.copy(sol::BoltzmannPolicy) =
    BoltzmannPolicy(copy(sol.policy), sol.temperature,
                    sol.clip_threshold, sol.rng)
4351
# Acting under a Boltzmann policy is always stochastic: delegate to sampling.
get_action(sol::BoltzmannPolicy, state::State) = rand_action(sol, state)
@@ -63,7 +71,7 @@ has_cached_action_values(sol::BoltzmannPolicy, state::State) =
6371function rand_action (sol:: BoltzmannPolicy , state:: State )
6472 if sol. temperature == 0
6573 # Reservoir sampling among maximal elements
66- qs = get_action_values (sol. policy , state)
74+ qs = get_action_values (sol, state)
6775 if isempty (qs) return missing end
6876 q_max = maximum (values (qs))
6977 n_max = 0
@@ -77,10 +85,18 @@ function rand_action(sol::BoltzmannPolicy, state::State)
7785 return chosen_act
7886 else
7987 # Reservoir sampling via Gumbel-max trick
88+ qs = get_action_values (sol, state)
89+ isempty (qs) && return missing
90+ if ! isnothing (sol. clip_threshold)
91+ max_score = maximum (values (qs)) / sol. temperature
92+ end
8093 chosen_act, chosen_score = missing , - Inf
81- for (act, q) in get_action_values (sol, state)
82- score = q / sol. temperature + randgumbel (sol. rng)
83- if score > chosen_score
94+ for (act, q) in qs
95+ score = q / sol. temperature
96+ if ! isnothing (sol. clip_threshold)
97+ score = max (score, max_score - sol. clip_threshold)
98+ end
99+ if score + randgumbel (sol. rng) > chosen_score
84100 chosen_act = act
85101 chosen_score = score
86102 end
@@ -100,7 +116,7 @@ function get_action_probs(sol::BoltzmannPolicy, state::State)
100116 n_max = sum (q >= q_max for q in q_values)
101117 probs = [q >= q_max ? 1.0 / n_max : 0.0 for q in q_values]
102118 else
103- probs = softmax (q_values ./ sol. temperature)
119+ probs = softmax (q_values ./ sol. temperature, sol . clip_threshold )
104120 end
105121 probs = Dict (zip (actions, probs))
106122 return probs
@@ -118,7 +134,7 @@ function get_action_prob(sol::BoltzmannPolicy, state::State, action::Term)
118134 q_act = get (action_values, action, - Inf )
119135 return q_act >= q_max ? 1.0 / n_max : 0.0
120136 else
121- probs = softmax (q_values ./ sol. temperature)
137+ probs = softmax (q_values ./ sol. temperature, sol . clip_threshold )
122138 for (a, p) in zip (actions, probs)
123139 a == action && return p
124140 end
@@ -127,42 +143,50 @@ function get_action_prob(sol::BoltzmannPolicy, state::State, action::Term)
127143end
128144
129145"""
130- BoltzmannMixturePolicy(policy, temperatures, [weights, rng::AbstractRNG])
146+ BoltzmannMixturePolicy(policy, temperatures, [weights,]
147+ [clip_threshold, rng::AbstractRNG])
131148
132149A mixture of Boltzmann policies with different `temperatures` and mixture
133150`weights`, specified as `Vector`s. If provided, `weights` must be non-negative
134151and sum to one. Otherwise a uniform mixture is assumed. Q-values are computed
135152according to the underlying `policy` provided as an argument to the constructor.
153+
154+ Similar to the [`BoltzmannPolicy`](@ref), `clip_threshold` can be used to
155+ prevent zero-probability actions.
136156"""
@auto_hash_equals struct BoltzmannMixturePolicy{P, R <: AbstractRNG} <: PolicySolution
    policy::P # Underlying policy that supplies the Q-values
    temperatures::Vector{Float64} # Temperature of each mixture component
    weights::Vector{Float64} # Non-negative component weights, summing to one
    clip_threshold::Union{Float64, Nothing} # Lower clip on (Q - Q_max)/T, or `nothing` to disable
    rng::R # Random number generator used when sampling actions
    function BoltzmannMixturePolicy{P, R}(
        policy::P, temperatures, weights, clip_threshold, rng::R
    ) where {P, R <: AbstractRNG}
        # Validate that the mixture specification is well-formed.
        @assert length(temperatures) == length(weights)
        @assert all(w >= 0 for w in weights)
        @assert isapprox(sum(weights), 1.0)
        temperatures = convert(Vector{Float64}, temperatures)
        weights = convert(Vector{Float64}, weights)
        return new(policy, temperatures, weights, clip_threshold, rng)
    end
end
153174
# Outer constructor: defaults to a uniform mixture, no clipping, global RNG.
function BoltzmannMixturePolicy(
    policy::P,
    temperatures,
    weights = ones(length(temperatures)) ./ length(temperatures),
    clip_threshold = nothing,
    rng::R = Random.GLOBAL_RNG
) where {P, R <: AbstractRNG}
    return BoltzmannMixturePolicy{P, R}(
        policy, temperatures, weights, clip_threshold, rng
    )
end
162186
# Convenience method: RNG supplied without weights or a clip threshold;
# assumes a uniform mixture over the given temperatures.
function BoltzmannMixturePolicy(policy, temperatures, rng::AbstractRNG)
    n = length(temperatures)
    uniform = ones(n) ./ n
    return BoltzmannMixturePolicy(policy, temperatures, uniform, nothing, rng)
end
167191
168192function Base. show (io:: IO , :: MIME"text/plain" , sol:: BoltzmannMixturePolicy )
173197
# Copy the wrapped policy and the mutable weight/temperature vectors.
# `clip_threshold` is passed through as-is: it is an immutable
# `Union{Float64, Nothing}`, and `copy(nothing)` would throw a `MethodError`
# since Base defines no `copy(::Nothing)` method.
Base.copy(sol::BoltzmannMixturePolicy) =
    BoltzmannMixturePolicy(copy(sol.policy), copy(sol.temperatures),
                           copy(sol.weights), sol.clip_threshold, sol.rng)
177201
# Acting under a Boltzmann mixture is always stochastic: delegate to sampling.
get_action(sol::BoltzmannMixturePolicy, state::State) = rand_action(sol, state)
@@ -196,7 +220,7 @@ has_cached_action_values(sol::BoltzmannMixturePolicy, state::State) =
196220
# Sample an action by first drawing a component temperature according to the
# mixture weights, then sampling from the corresponding Boltzmann policy.
function rand_action(sol::BoltzmannMixturePolicy, state::State)
    temp = sample(sol.rng, sol.temperatures, Weights(sol.weights))
    component = BoltzmannPolicy(sol.policy, temp, sol.clip_threshold, sol.rng)
    return rand_action(component, state)
end
202226
@@ -216,7 +240,7 @@ function get_action_probs(sol::BoltzmannMixturePolicy, state::State)
216240 probs[i] += weight / n_max
217241 end
218242 else
219- probs .+ = softmax (q_values ./ temp) .* weight
243+ probs .+ = softmax (q_values ./ temp, sol . clip_threshold ) .* weight
220244 end
221245 end
222246 probs = Dict (zip (actions, probs))
@@ -239,7 +263,7 @@ function get_action_prob(sol::BoltzmannMixturePolicy,
239263 q_act = q_values[act_idx]
240264 act_prob += q_act >= q_max ? weight / n_max : 0.0
241265 else
242- probs = softmax (q_values ./ temp)
266+ probs = softmax (q_values ./ temp, sol . clip_threshold )
243267 act_prob += probs[act_idx] * weight
244268 end
245269 end
@@ -267,7 +291,7 @@ function get_mixture_weights(sol::BoltzmannMixturePolicy,
267291 q_act = q_values[act_idx]
268292 return q_act >= q_max ? weight / n_max : 0.0
269293 else
270- probs = softmax (q_values ./ temp)
294+ probs = softmax (q_values ./ temp, sol . clip_threshold )
271295 return probs[act_idx] * weight
272296 end
273297 end
0 commit comments