-
Notifications
You must be signed in to change notification settings - Fork 4
Features needed for routing env in qiskit-gym #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,19 +16,114 @@ use rand::{prelude::Distribution, Rng}; | |
| use crate::nn::modules::Sequential; | ||
| use crate::nn::layers::EmbeddingBag; | ||
|
|
||
/// How the policy head's raw outputs are interpreted when producing
/// per-action logits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ActionMode {
    /// The action net emits one logit per discrete action.
    Categorical,
    /// The action net emits one logit per binary factor; these are
    /// expanded into categorical logits on demand.
    FactorizedBernoulli,
}

impl ActionMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    /// Any name other than "factorized_bernoulli" falls back to
    /// `Categorical` (the pre-existing default).
    fn from_name(name: &str) -> Self {
        if name.trim().eq_ignore_ascii_case("factorized_bernoulli") {
            Self::FactorizedBernoulli
        } else {
            Self::Categorical
        }
    }
}
|
|
||
| #[derive(Clone)] | ||
| pub struct Policy { | ||
| embeddings: Box<EmbeddingBag>, | ||
| common: Box<Sequential>, | ||
| action_net: Box<Sequential>, | ||
| value_net: Box<Sequential>, | ||
| obs_perms: Vec<Vec<usize>>, | ||
| act_perms: Vec<Vec<usize>> | ||
| act_perms: Vec<Vec<usize>>, | ||
| action_mode: ActionMode, | ||
| num_action_factors: usize, | ||
| num_actions: usize, | ||
| } | ||
|
|
||
| impl Policy { | ||
| pub fn new(embeddings: Box<EmbeddingBag>, common: Box<Sequential>, action_net: Box<Sequential>, value_net: Box<Sequential>, obs_perms: Vec<Vec<usize>>, act_perms: Vec<Vec<usize>>) -> Self { | ||
| Self { embeddings: embeddings, common, action_net, value_net, obs_perms, act_perms } | ||
| let inferred_num_actions = act_perms.first().map(|p| p.len()).unwrap_or(0); | ||
| Self { | ||
| embeddings, | ||
| common, | ||
| action_net, | ||
| value_net, | ||
| obs_perms, | ||
| act_perms, | ||
| action_mode: ActionMode::Categorical, | ||
| num_action_factors: 0, | ||
| num_actions: inferred_num_actions, | ||
| } | ||
| } | ||
|
|
||
| pub fn new_with_action_mode( | ||
| embeddings: Box<EmbeddingBag>, | ||
| common: Box<Sequential>, | ||
| action_net: Box<Sequential>, | ||
| value_net: Box<Sequential>, | ||
| obs_perms: Vec<Vec<usize>>, | ||
| act_perms: Vec<Vec<usize>>, | ||
| action_mode: String, | ||
| num_action_factors: usize, | ||
| num_actions: usize, | ||
| ) -> Self { | ||
| let mut out = Self::new(embeddings, common, action_net, value_net, obs_perms, act_perms); | ||
| out.action_mode = ActionMode::from_name(&action_mode); | ||
| out.num_action_factors = num_action_factors; | ||
| out.num_actions = if num_actions > 0 { | ||
| num_actions | ||
| } else { | ||
| out.num_actions | ||
| }; | ||
| out | ||
| } | ||
|
|
||
| fn effective_num_actions(&self) -> usize { | ||
| if self.num_actions > 0 { | ||
| return self.num_actions; | ||
| } | ||
| if let Some(first_perm) = self.act_perms.first() { | ||
| if !first_perm.is_empty() { | ||
| return first_perm.len(); | ||
| } | ||
| } | ||
| if self.action_mode == ActionMode::FactorizedBernoulli && self.num_action_factors > 0 { | ||
| return 1usize.checked_shl(self.num_action_factors as u32).unwrap_or(0); | ||
| } | ||
| 0 | ||
| } | ||
|
|
||
| fn expand_factorized_logits(&self, factor_logits: &[f32]) -> Vec<f32> { | ||
| // Convert per-factor logits into per-action logits by summing logits of active bits. | ||
| let num_factors = if self.num_action_factors > 0 { | ||
| self.num_action_factors | ||
| } else { | ||
| factor_logits.len() | ||
| }; | ||
| if num_factors == 0 || factor_logits.len() < num_factors { | ||
| return factor_logits.to_vec(); | ||
| } | ||
| let num_actions = self.effective_num_actions(); | ||
| if num_actions == 0 { | ||
| return factor_logits.to_vec(); | ||
| } | ||
| let mut expanded = vec![0.0f32; num_actions]; | ||
| for action in 0..num_actions { | ||
| let mut logit = 0.0f32; | ||
| for bit in 0..num_factors { | ||
| if ((action >> bit) & 1usize) == 1usize { | ||
| logit += factor_logits[bit]; | ||
| } | ||
| } | ||
| expanded[action] = logit; | ||
| } | ||
| expanded | ||
| } | ||
|
|
||
| pub fn predict(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32) { | ||
|
|
@@ -89,11 +184,19 @@ impl Policy { | |
| let value = self.value_net.forward(common_out.clone()).sum(); // This only has one element | ||
|
|
||
| // Forward of the action net | ||
| let mut action_logits = self.action_net.forward(common_out).data.as_vec().to_owned(); | ||
| let raw_action_logits = self.action_net.forward(common_out).data.as_vec().to_owned(); | ||
| let mut action_logits = match self.action_mode { | ||
| ActionMode::Categorical => raw_action_logits, | ||
| ActionMode::FactorizedBernoulli => self.expand_factorized_logits(&raw_action_logits), | ||
| }; | ||
|
|
||
| // Permute logits according to the corresponding act_perm | ||
| if let Some(pi) = n_perm { | ||
| action_logits = self.act_perms[pi].iter().map(|&v| action_logits[v]).collect(); | ||
| if let Some(act_perm) = self.act_perms.get(pi) { | ||
| if act_perm.len() == action_logits.len() { | ||
| action_logits = act_perm.iter().map(|&v| action_logits[v]).collect(); | ||
| } | ||
| } | ||
|
Comment on lines
+195
to
+199
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition |
||
| } | ||
|
|
||
| (action_logits, value) | ||
|
|
@@ -103,7 +206,7 @@ impl Policy { | |
| if self.obs_perms.len() == 0 {return self.predict(obs, masks);}; | ||
|
|
||
| // Forward of the action net for each perm | ||
| let mut action_logits = vec![0.0f32; self.act_perms[0].len()]; | ||
| let mut action_logits = vec![0.0f32; self.effective_num_actions()]; | ||
| let mut value = 0.0f32; | ||
|
|
||
| for pi in 0..self.obs_perms.len() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In `expand_factorized_logits`, if `num_factors` is non-zero but `factor_logits.len()` is less than `num_factors`, or if `effective_num_actions()` returns 0, the function currently returns `factor_logits.to_vec()`. This behavior might mask an underlying configuration error or an invalid state. It would be more robust to explicitly handle these as error conditions, perhaps by logging a warning or raising an error, to prevent silent misbehavior in the reinforcement learning process.