Skip to content

Commit d63ad9b

Browse files
committed
Support clipping in Boltzmann policies.
1 parent f9731b3 commit d63ad9b

File tree

3 files changed

+82
-35
lines changed

3 files changed

+82
-35
lines changed

src/solutions/boltzmann_policy.jl

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export BoltzmannPolicy, BoltzmannMixturePolicy
22

33
"""
4-
BoltzmannPolicy(policy, temperature, [rng::AbstractRNG])
4+
BoltzmannPolicy(policy, temperature, [clip_threshold, rng::AbstractRNG])
55
66
Policy that samples actions according to a Boltzmann distribution with the
77
specified `temperature`. The unnormalized log probability of taking an action
@@ -16,6 +16,11 @@ Higher temperatures lead to an increasingly random policy, whereas a temperature
1616
of zero corresponds to a deterministic policy. Q-values are computed according
1717
to the underlying `policy` provided as an argument to the constructor.
1818
19+
A `clip_threshold` can be specified as a negative value, so that
20+
``[Q(s, a) - Q_max] / T`` is clipped to a minimum of `clip_threshold`. This
21+
prevents zero-probability actions, and ensures that the relative odds of the
22+
most vs. least probable action is no smaller than `exp(clip_threshold)`.
23+
1924
Note that wrapping an existing policy in a `BoltzmannPolicy` does not ensure
2025
consistency of the state values ``V`` and Q-values ``Q`` according to the
2126
Bellman equation, since this would require repeated Bellman updates to ensure
@@ -24,22 +29,25 @@ convergence.
2429
@auto_hash_equals struct BoltzmannPolicy{P, R <: AbstractRNG} <: PolicySolution
2530
policy::P
2631
temperature::Float64
32+
clip_threshold::Union{Float64, Nothing}
2733
rng::R
2834
end
2935

30-
BoltzmannPolicy(policy::BoltzmannPolicy, temperature, rng) =
31-
BoltzmannPolicy(policy.policy, temperature, rng)
32-
33-
BoltzmannPolicy(policy, temperature) =
34-
BoltzmannPolicy(policy, temperature, Random.GLOBAL_RNG)
36+
BoltzmannPolicy(policy::BoltzmannPolicy, temperature, clip_threshold, rng) =
37+
BoltzmannPolicy(policy.policy, temperature, clip_threshold, rng)
38+
BoltzmannPolicy(policy, temperature, clip_threshold = nothing) =
39+
BoltzmannPolicy(policy, temperature, clip_threshold, Random.GLOBAL_RNG)
40+
BoltzmannPolicy(policy, temperature, rng::AbstractRNG) =
41+
BoltzmannPolicy(policy, temperature, nothing, rng)
3542

3643
function Base.show(io::IO, ::MIME"text/plain", sol::BoltzmannPolicy)
3744
indent = get(io, :indent, "")
3845
show_struct(io, sol; indent = indent, show_fields=(:policy,))
3946
end
4047

4148
Base.copy(sol::BoltzmannPolicy) =
42-
BoltzmannPolicy(copy(sol.policy), sol.temperature, sol.rng)
49+
BoltzmannPolicy(copy(sol.policy), sol.temperature,
50+
sol.clip_threshold, sol.rng)
4351

4452
get_action(sol::BoltzmannPolicy, state::State) =
4553
rand_action(sol, state)
@@ -63,7 +71,7 @@ has_cached_action_values(sol::BoltzmannPolicy, state::State) =
6371
function rand_action(sol::BoltzmannPolicy, state::State)
6472
if sol.temperature == 0
6573
# Reservoir sampling among maximal elements
66-
qs = get_action_values(sol.policy, state)
74+
qs = get_action_values(sol, state)
6775
if isempty(qs) return missing end
6876
q_max = maximum(values(qs))
6977
n_max = 0
@@ -77,10 +85,18 @@ function rand_action(sol::BoltzmannPolicy, state::State)
7785
return chosen_act
7886
else
7987
# Reservoir sampling via Gumbel-max trick
88+
qs = get_action_values(sol, state)
89+
isempty(qs) && return missing
90+
if !isnothing(sol.clip_threshold)
91+
max_score = maximum(values(qs)) / sol.temperature
92+
end
8093
chosen_act, chosen_score = missing, -Inf
81-
for (act, q) in get_action_values(sol, state)
82-
score = q / sol.temperature + randgumbel(sol.rng)
83-
if score > chosen_score
94+
for (act, q) in qs
95+
score = q / sol.temperature
96+
if !isnothing(sol.clip_threshold)
97+
score = max(score, max_score - sol.clip_threshold)
98+
end
99+
if score + randgumbel(sol.rng) > chosen_score
84100
chosen_act = act
85101
chosen_score = score
86102
end
@@ -100,7 +116,7 @@ function get_action_probs(sol::BoltzmannPolicy, state::State)
100116
n_max = sum(q >= q_max for q in q_values)
101117
probs = [q >= q_max ? 1.0 / n_max : 0.0 for q in q_values]
102118
else
103-
probs = softmax(q_values ./ sol.temperature)
119+
probs = softmax(q_values ./ sol.temperature, sol.clip_threshold)
104120
end
105121
probs = Dict(zip(actions, probs))
106122
return probs
@@ -118,7 +134,7 @@ function get_action_prob(sol::BoltzmannPolicy, state::State, action::Term)
118134
q_act = get(action_values, action, -Inf)
119135
return q_act >= q_max ? 1.0 / n_max : 0.0
120136
else
121-
probs = softmax(q_values ./ sol.temperature)
137+
probs = softmax(q_values ./ sol.temperature, sol.clip_threshold)
122138
for (a, p) in zip(actions, probs)
123139
a == action && return p
124140
end
@@ -127,42 +143,50 @@ function get_action_prob(sol::BoltzmannPolicy, state::State, action::Term)
127143
end
128144

129145
"""
130-
BoltzmannMixturePolicy(policy, temperatures, [weights, rng::AbstractRNG])
146+
BoltzmannMixturePolicy(policy, temperatures, [weights,]
147+
[clip_threshold, rng::AbstractRNG])
131148
132149
A mixture of Boltzmann policies with different `temperatures` and mixture
133150
`weights`, specified as `Vector`s. If provided, `weights` must be non-negative
134151
and sum to one. Otherwise a uniform mixture is assumed. Q-values are computed
135152
according to the underlying `policy` provided as an argument to the constructor.
153+
154+
Similar to the [`BoltzmannPolicy`](@ref), `clip_threshold` can be used to
155+
prevent zero-probability actions.
136156
"""
137157
@auto_hash_equals struct BoltzmannMixturePolicy{P, R <: AbstractRNG} <: PolicySolution
138158
policy::P
139159
temperatures::Vector{Float64}
140160
weights::Vector{Float64}
161+
clip_threshold::Union{Float64, Nothing}
141162
rng::R
142163
function BoltzmannMixturePolicy{P, R}(
143-
policy::P, temperatures, weights, rng::R
164+
policy::P, temperatures, weights, clip_threshold, rng::R
144165
) where {P, R <: AbstractRNG}
145166
@assert length(temperatures) == length(weights)
146167
@assert all(w >= 0 for w in weights)
147168
@assert isapprox(sum(weights), 1.0)
148169
temperatures = convert(Vector{Float64}, temperatures)
149170
weights = convert(Vector{Float64}, weights)
150-
return new(policy, temperatures, weights, rng)
171+
return new(policy, temperatures, weights, clip_threshold, rng)
151172
end
152173
end
153174

154175
function BoltzmannMixturePolicy(
155176
policy::P,
156177
temperatures,
157178
weights = ones(length(temperatures)) ./ length(temperatures),
179+
clip_threshold = nothing,
158180
rng::R = Random.GLOBAL_RNG
159181
) where {P, R <: AbstractRNG}
160-
return BoltzmannMixturePolicy{P, R}(policy, temperatures, weights, rng)
182+
return BoltzmannMixturePolicy{P, R}(
183+
policy, temperatures, weights, clip_threshold, rng
184+
)
161185
end
162186

163187
function BoltzmannMixturePolicy(policy, temperatures, rng::AbstractRNG)
164188
weights = ones(length(temperatures)) ./ length(temperatures)
165-
return BoltzmannMixturePolicy(policy, temperatures, weights, rng)
189+
return BoltzmannMixturePolicy(policy, temperatures, weights, nothing, rng)
166190
end
167191

168192
function Base.show(io::IO, ::MIME"text/plain", sol::BoltzmannMixturePolicy)
@@ -173,7 +197,7 @@ end
173197

174198
Base.copy(sol::BoltzmannMixturePolicy) =
175199
BoltzmannMixturePolicy(copy(sol.policy), copy(sol.temperatures),
176-
copy(sol.weights), sol.rng)
200+
copy(sol.weights), copy(sol.clip_threshold), sol.rng)
177201

178202
get_action(sol::BoltzmannMixturePolicy, state::State) =
179203
rand_action(sol, state)
@@ -196,7 +220,7 @@ has_cached_action_values(sol::BoltzmannMixturePolicy, state::State) =
196220

197221
function rand_action(sol::BoltzmannMixturePolicy, state::State)
198222
temperature = sample(sol.rng, sol.temperatures, Weights(sol.weights))
199-
policy = BoltzmannPolicy(sol.policy, temperature, sol.rng)
223+
policy = BoltzmannPolicy(sol.policy, temperature, sol.clip_threshold, sol.rng)
200224
return rand_action(policy, state)
201225
end
202226

@@ -216,7 +240,7 @@ function get_action_probs(sol::BoltzmannMixturePolicy, state::State)
216240
probs[i] += weight / n_max
217241
end
218242
else
219-
probs .+= softmax(q_values ./ temp) .* weight
243+
probs .+= softmax(q_values ./ temp, sol.clip_threshold) .* weight
220244
end
221245
end
222246
probs = Dict(zip(actions, probs))
@@ -239,7 +263,7 @@ function get_action_prob(sol::BoltzmannMixturePolicy,
239263
q_act = q_values[act_idx]
240264
act_prob += q_act >= q_max ? weight / n_max : 0.0
241265
else
242-
probs = softmax(q_values ./ temp)
266+
probs = softmax(q_values ./ temp, sol.clip_threshold)
243267
act_prob += probs[act_idx] * weight
244268
end
245269
end
@@ -267,7 +291,7 @@ function get_mixture_weights(sol::BoltzmannMixturePolicy,
267291
q_act = q_values[act_idx]
268292
return q_act >= q_max ? weight / n_max : 0.0
269293
else
270-
probs = softmax(q_values ./ temp)
294+
probs = softmax(q_values ./ temp, sol.clip_threshold)
271295
return probs[act_idx] * weight
272296
end
273297
end

src/utils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ function auto_equals_expr(type, fields)
6666
end
6767

6868
"Convert vector of unnormalized scores to probabilities."
69-
function softmax(scores)
69+
function softmax(scores, min_rel_score=nothing)
7070
if isempty(scores) return Float64[] end
71-
ws = exp.(scores .- maximum(scores))
71+
rel_scores = scores .- maximum(scores)
72+
if !isnothing(min_rel_score)
73+
rel_scores = max.(rel_scores, min_rel_score)
74+
end
75+
ws = exp.(rel_scores)
7276
z = sum(ws)
7377
return isnan(z) ? ones(length(scores)) ./ length(scores) : ws ./ z
7478
end

test/solutions.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,17 +321,26 @@ act_prob = probs[pddl"(pick-up a)"]
321321
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ act_prob
322322
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
323323

324-
sol = TabularPolicy()
325-
sol.V[hash(bw_state)] = bw_init_v
326-
sol.Q[hash(bw_state)] = copy(bw_init_q)
327-
sol.Q[hash(bw_state)][pddl"(pick-up a)"] = -Inf
328-
sol = BoltzmannPolicy(sol, 0.0)
324+
v_sol = TabularPolicy()
325+
v_sol.V[hash(bw_state)] = bw_init_v
326+
v_sol.Q[hash(bw_state)] = copy(bw_init_q)
327+
v_sol.Q[hash(bw_state)][pddl"(pick-up a)"] = -Inf
328+
329+
sol = BoltzmannPolicy(v_sol, 0.0)
329330
probs = Dict(a => a == pddl"(pick-up a)" ? 0.0 : 0.5 for a in bw_init_actions)
330331
@test all(probs[a] ≈ p for (a, p) in get_action_probs(sol, bw_state))
331332
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 0.0
332333
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.5
333334
@test get_action_prob(sol, bw_state, pddl"(pick-up c)") == 0.5
334335

336+
sol = BoltzmannPolicy(v_sol, 2.0, -8.0)
337+
prob_a = get_action_prob(sol, bw_state, pddl"(pick-up a)")
338+
prob_b = get_action_prob(sol, bw_state, pddl"(pick-up b)")
339+
prob_c = get_action_prob(sol, bw_state, pddl"(pick-up c)")
340+
@test log(prob_a) - log(prob_b) ≈ -8.0
341+
@test log(prob_a) - log(prob_c) ≈ -8.0
342+
@test log(prob_b) - log(prob_c) ≈ 0.0
343+
335344
@test has_values(sol) == true
336345

337346
@test copy(sol) == sol
@@ -373,11 +382,12 @@ new_weights = [0.4 * probs_1[act], 0.6 * probs_2[act]]
373382
new_weights = new_weights ./ sum(new_weights)
374383
@test get_mixture_weights(sol, bw_state, act) ≈ new_weights
375384

376-
sol = TabularPolicy()
377-
sol.V[hash(bw_state)] = bw_init_v
378-
sol.Q[hash(bw_state)] = copy(bw_init_q)
379-
sol.Q[hash(bw_state)][pddl"(pick-up a)"] = -6.0 - log(2)
380-
sol = BoltzmannMixturePolicy(sol, [0.0, 1.0])
385+
v_sol = TabularPolicy()
386+
v_sol.V[hash(bw_state)] = bw_init_v
387+
v_sol.Q[hash(bw_state)] = copy(bw_init_q)
388+
389+
sol = BoltzmannMixturePolicy(v_sol, [0.0, 1.0])
390+
v_sol.Q[hash(bw_state)][pddl"(pick-up a)"] = -6.0 - log(2)
381391
probs = Dict(a => a == pddl"(pick-up a)" ? 0.1 : 0.45 for a in bw_init_actions)
382392
@test all(probs[a] ≈ p for (a, p) in get_action_probs(sol, bw_state))
383393
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ 0.1
@@ -387,6 +397,15 @@ new_weights = [0.5, 0.4] .* get_mixture_weights(sol)
387397
new_weights = new_weights ./ sum(new_weights)
388398
@test get_mixture_weights(sol, bw_state, pddl"(pick-up b)") ≈ new_weights
389399

400+
sol = BoltzmannMixturePolicy(v_sol, [1.0, 2.0], [0.25, 0.75], -8.0)
401+
v_sol.Q[hash(bw_state)][pddl"(pick-up a)"] = -Inf
402+
prob_a = get_action_prob(sol, bw_state, pddl"(pick-up a)")
403+
prob_b = get_action_prob(sol, bw_state, pddl"(pick-up b)")
404+
prob_c = get_action_prob(sol, bw_state, pddl"(pick-up c)")
405+
@test log(prob_a) - log(prob_b) ≈ -8.0
406+
@test log(prob_a) - log(prob_c) ≈ -8.0
407+
@test log(prob_b) - log(prob_c) ≈ 0.0
408+
390409
@test has_values(sol) == true
391410

392411
@test copy(sol) == sol

0 commit comments

Comments
 (0)