diff --git a/agent/asyncrl/policy_output.py b/agent/asyncrl/policy_output.py index c7d55d3..001d192 100644 --- a/agent/asyncrl/policy_output.py +++ b/agent/asyncrl/policy_output.py @@ -18,11 +18,15 @@ def _sample_discrete_actions(batch_probs): List consisting of sampled actions """ action_indices = [] - - # Subtract a tiny value from probabilities in order to avoid - # "ValueError: sum(pvals[:-1]) > 1.0" in numpy.multinomial - batch_probs = batch_probs - np.finfo(np.float32).epsneg - + # Prevent having a vector which sum is not in [0, 1] + while not 0 < np.sum(batch_probs) < 1: + # Subtract a tiny value from probabilities in order to avoid + # "ValueError: sum(pvals[:-1]) > 1.0" in numpy.multinomial + batch_probs = batch_probs - np.finfo(np.float32).epsneg + # Apply abs function to keep probability values positive to avoid + # "ValueError: pvals < 0, pvals > 1 or pvals contains NaNs" in numpy.multinomial + batch_probs = np.absolute(batch_probs) + for i in range(batch_probs.shape[0]): histogram = np.random.multinomial(1, batch_probs[i]) action_indices.append(int(np.nonzero(histogram)[0]))