Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions rust/src/collector/ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,16 @@ impl PPOCollector {
}

impl PPOCollector {
fn get_step_data(
fn get_action_data(
&self,
env: &dyn Env,
policy: &Policy,
) -> (Vec<usize>, Vec<f32>, usize, f32, f32, Option<usize>) {
) -> (Vec<usize>, Vec<f32>, usize, f32, Option<usize>) {
let obs = env.observe(); // Vec<f32> or whatever your Env returns
let masks = env.masks();
let reward = env.reward();
let (logits, value, perm_idx) = policy.forward_with_perm(obs.clone(), masks);
let action = sample_from_logits(&logits);
(obs, logits, action, value, reward, perm_idx)
(obs, logits, action, value, perm_idx)
}

fn single_collect(
Expand All @@ -67,20 +66,31 @@ impl PPOCollector {
let mut perms = Vec::new();

loop {
let (obs, log_prob, act, val, rew, perm_idx) = self.get_step_data(&*env, policy);
if env.is_final() { break; }
let (obs, log_prob, act, val, perm_idx) = self.get_action_data(&*env, policy);
// Step first, then read reward so collected transitions follow env step semantics.
env.step(act);
let rew = env.reward();
obss.push(obs);
log_probs.push(log_prob);
vals.push(val);
rews.push(rew);
acts.push(act);
perms.push(perm_idx);

if env.is_final() { break; }
env.step(act);
}

// compute GAE advs/rets
let n = rews.len();
if n == 0 {
return CollectedData::new(
obss,
log_probs,
perms,
vals,
rews,
acts,
);
}
let mut advs = vec![0.0; n];
let mut rets = vec![0.0; n];
advs[n-1] = rews[n-1] - vals[n-1];
Expand Down Expand Up @@ -177,7 +187,7 @@ mod tests {
let collector = PPOCollector::new(1, 0.9, 0.95, 1);

let data = collector.collect(&env, &policy).unwrap();
assert_eq!(data.obs.len(), 2);
assert_eq!(data.obs.len(), 1);
assert!(data.additional_data.contains_key("rets"));
}
}
113 changes: 108 additions & 5 deletions rust/src/nn/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,114 @@ use rand::{prelude::Distribution, Rng};
use crate::nn::modules::Sequential;
use crate::nn::layers::EmbeddingBag;

/// How raw network outputs map onto the discrete action space.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ActionMode {
    /// One logit per discrete action.
    Categorical,
    /// One logit per binary factor; expanded to categorical logits on demand.
    FactorizedBernoulli,
}

impl ActionMode {
    /// Parses a mode name, ignoring surrounding whitespace and letter case.
    /// Any unrecognized name falls back to `Categorical`.
    fn from_name(name: &str) -> Self {
        let normalized = name.trim().to_ascii_lowercase();
        if normalized == "factorized_bernoulli" {
            Self::FactorizedBernoulli
        } else {
            Self::Categorical
        }
    }
}

/// Actor-critic policy: a shared embedding + common trunk feeding separate
/// action and value heads, with optional observation/action permutations
/// for symmetry-averaged prediction.
#[derive(Clone)]
pub struct Policy {
    // Embedding layer applied to the (index-based) observation.
    embeddings: Box<EmbeddingBag>,
    // Shared trunk whose output feeds both heads.
    common: Box<Sequential>,
    // Head producing action logits (per action, or per factor in
    // FactorizedBernoulli mode).
    action_net: Box<Sequential>,
    // Head producing the scalar state-value estimate.
    value_net: Box<Sequential>,
    // Observation index permutations; paired entry-wise with `act_perms`.
    obs_perms: Vec<Vec<usize>>,
    // Action index permutations applied to the logits after the forward pass.
    act_perms: Vec<Vec<usize>>,
    // How raw action-net outputs map onto the action space.
    action_mode: ActionMode,
    // Number of binary factors in FactorizedBernoulli mode; 0 when unused.
    num_action_factors: usize,
    // Explicit action count; 0 means "infer" (see `effective_num_actions`).
    num_actions: usize,
}

impl Policy {
pub fn new(embeddings: Box<EmbeddingBag>, common: Box<Sequential>, action_net: Box<Sequential>, value_net: Box<Sequential>, obs_perms: Vec<Vec<usize>>, act_perms: Vec<Vec<usize>>) -> Self {
Self { embeddings: embeddings, common, action_net, value_net, obs_perms, act_perms }
let inferred_num_actions = act_perms.first().map(|p| p.len()).unwrap_or(0);
Self {
embeddings,
common,
action_net,
value_net,
obs_perms,
act_perms,
action_mode: ActionMode::Categorical,
num_action_factors: 0,
num_actions: inferred_num_actions,
}
}

pub fn new_with_action_mode(
embeddings: Box<EmbeddingBag>,
common: Box<Sequential>,
action_net: Box<Sequential>,
value_net: Box<Sequential>,
obs_perms: Vec<Vec<usize>>,
act_perms: Vec<Vec<usize>>,
action_mode: String,
num_action_factors: usize,
num_actions: usize,
) -> Self {
let mut out = Self::new(embeddings, common, action_net, value_net, obs_perms, act_perms);
out.action_mode = ActionMode::from_name(&action_mode);
out.num_action_factors = num_action_factors;
out.num_actions = if num_actions > 0 {
num_actions
} else {
out.num_actions
};
out
}

fn effective_num_actions(&self) -> usize {
if self.num_actions > 0 {
return self.num_actions;
}
if let Some(first_perm) = self.act_perms.first() {
if !first_perm.is_empty() {
return first_perm.len();
}
}
if self.action_mode == ActionMode::FactorizedBernoulli && self.num_action_factors > 0 {
return 1usize.checked_shl(self.num_action_factors as u32).unwrap_or(0);
}
0
}

fn expand_factorized_logits(&self, factor_logits: &[f32]) -> Vec<f32> {
// Convert per-factor logits into per-action logits by summing logits of active bits.
let num_factors = if self.num_action_factors > 0 {
self.num_action_factors
} else {
factor_logits.len()
};
if num_factors == 0 || factor_logits.len() < num_factors {
return factor_logits.to_vec();
}
let num_actions = self.effective_num_actions();
if num_actions == 0 {
return factor_logits.to_vec();
}
Comment on lines +109 to +115

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In expand_factorized_logits, if num_factors is non-zero but factor_logits.len() is less than num_factors, or if effective_num_actions() returns 0, the function currently returns factor_logits.to_vec(). This behavior might mask an underlying configuration error or an invalid state. It would be more robust to explicitly handle these as error conditions, perhaps by logging a warning or raising an error, to prevent silent misbehavior in the reinforcement learning process.

let mut expanded = vec![0.0f32; num_actions];
for action in 0..num_actions {
let mut logit = 0.0f32;
for bit in 0..num_factors {
if ((action >> bit) & 1usize) == 1usize {
logit += factor_logits[bit];
}
}
expanded[action] = logit;
}
expanded
}

pub fn predict(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32) {
Expand Down Expand Up @@ -89,11 +184,19 @@ impl Policy {
let value = self.value_net.forward(common_out.clone()).sum(); // This only has one element

// Forward of the action net
let mut action_logits = self.action_net.forward(common_out).data.as_vec().to_owned();
let raw_action_logits = self.action_net.forward(common_out).data.as_vec().to_owned();
let mut action_logits = match self.action_mode {
ActionMode::Categorical => raw_action_logits,
ActionMode::FactorizedBernoulli => self.expand_factorized_logits(&raw_action_logits),
};

// Permute logits according to the corresponding act_perm
if let Some(pi) = n_perm {
action_logits = self.act_perms[pi].iter().map(|&v| action_logits[v]).collect();
if let Some(act_perm) = self.act_perms.get(pi) {
if act_perm.len() == action_logits.len() {
action_logits = act_perm.iter().map(|&v| action_logits[v]).collect();
}
}
Comment on lines +195 to +199

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition if act_perm.len() == action_logits.len() is a good check. However, if the lengths do not match, the permutation is silently skipped. This could lead to unexpected behavior if the permutation was intended to be applied but couldn't due to a mismatch. Consider logging a warning or raising an error in this else branch to make such inconsistencies explicit, aiding in debugging and preventing silent failures.

}

(action_logits, value)
Expand All @@ -103,7 +206,7 @@ impl Policy {
if self.obs_perms.len() == 0 {return self.predict(obs, masks);};

// Forward of the action net for each perm
let mut action_logits = vec![0.0f32; self.act_perms[0].len()];
let mut action_logits = vec![0.0f32; self.effective_num_actions()];
let mut value = 0.0f32;

for pi in 0..self.obs_perms.len() {
Expand Down
35 changes: 33 additions & 2 deletions rust/src/python_interface/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,39 @@ pub struct PyPolicy {
#[pymethods]
impl PyPolicy {
#[new]
pub fn new(embeddings: PyEmbeddingBag, common: PySequential, action_net: PySequential, value_net: PySequential, obs_perms: Vec<Vec<usize>>, act_perms: Vec<Vec<usize>>) -> Self {
let policy = Box::new(Policy::new(embeddings.embedding, common.seq, action_net.seq, value_net.seq, obs_perms, act_perms));
#[pyo3(signature = (
embeddings,
common,
action_net,
value_net,
obs_perms,
act_perms,
action_mode = "categorical",
num_action_factors = 0,
num_actions = 0
))]
pub fn new(
embeddings: PyEmbeddingBag,
common: PySequential,
action_net: PySequential,
value_net: PySequential,
obs_perms: Vec<Vec<usize>>,
act_perms: Vec<Vec<usize>>,
action_mode: &str,
num_action_factors: usize,
num_actions: usize,
) -> Self {
let policy = Box::new(Policy::new_with_action_mode(
embeddings.embedding,
common.seq,
action_net.seq,
value_net.seq,
obs_perms,
act_perms,
action_mode.to_string(),
num_action_factors,
num_actions,
));
PyPolicy { policy }
}

Expand Down
12 changes: 11 additions & 1 deletion src/twisterl/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,17 @@

# Learning

# Curriculum / difficulty-scheduling defaults.
LEARNING_CONFIG = {
    # Success threshold that must be reached before difficulty increases.
    "diff_threshold": 0.85,
    # Lower hysteresis threshold used to avoid rapid on/off difficulty toggling.
    "threshold_min": 0.85,
    # Hard cap on the difficulty value.
    "diff_max": 256,
    # Difficulty increment applied once past warmup.
    "diff_step": 1,
    # While difficulty <= warmup, keep +1 increments regardless of diff_step.
    "warmup": 0,
    # Presumably: whether difficulty becomes None at the end of the
    # curriculum — TODO confirm against the scheduler that consumes this.
    "final_diff_is_none": False,
    # Metric used to judge whether the threshold has been met.
    "diff_metric": "ppo_deterministic",
}


# Logging and checkpoints
Expand Down
Loading
Loading