Skip to content

Commit f9731b3

Browse files
committed
Skip zero-weight components in mixture policy.
1 parent 4fc53f9 commit f9731b3

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/solutions/mixture_policy.jl

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ end
6363
function get_action_probs(sol::MixturePolicy, state::State)
6464
probs = Dict{Term, Float64}()
6565
for (policy, weight) in zip(sol.policies, sol.weights)
66+
iszero(weight) && continue
6667
for (action, prob) in get_action_probs(policy, state)
6768
probs[action] = get(probs, action, 0.0) + prob * weight
6869
end
@@ -73,6 +74,7 @@ end
7374
function get_action_prob(sol::MixturePolicy, state::State, action::Term)
7475
prob = 0.0
7576
for (policy, weight) in zip(sol.policies, sol.weights)
77+
iszero(weight) && continue
7678
prob += get_action_prob(policy, state, action) * weight
7779
end
7880
return prob
@@ -98,6 +100,7 @@ end
98100
function get_mixture_weights(sol::MixturePolicy, state::State, action::Term;
99101
normalize::Bool = true)
100102
joint_probs = map(zip(sol.policies, sol.weights)) do (policy, weight)
103+
iszero(weight) && return 0.0
101104
return get_action_prob(policy, state, action) * weight
102105
end
103106
new_weights = normalize ? joint_probs ./ sum(joint_probs) : joint_probs

0 commit comments

Comments
 (0)