
Commit 3637aa9

Add permutations as twists in training (#27)
* Add permutations as twists in training
* Fix lint
* Add documentation on twists
* some fixes
1 parent 88714e3 commit 3637aa9

File tree

13 files changed (+335, -55 lines)


README.md

Lines changed: 6 additions & 1 deletion
```diff
@@ -100,11 +100,16 @@ The `examples/grid_world` custom environment example [here](examples/grid_world)
 
 Refer to [grid_world](examples/grid_world) for a complete working example.
 
+## Documentation
+
+- [Permutation twists in environments](docs/twists.md)
+
 ## 🚀 Key Features
 - **High-Performance Core**: RL episode loop implemented in Rust for faster training and inference
 - **Inference-Ready**: Easy compilation and bundling of models with environments into portable binaries for inference
 - **Modular Design**: Support for multiple algorithms (PPO, AlphaZero) with interchangeable training and inference
 - **Language Interoperability**: Core in Rust with Python interface
+- **Symmetry-Aware Training via Twists**: Environments can expose observation/action permutations (“twists”) so policies automatically exploit device or puzzle symmetries for faster learning.
 
 
 ## 🏗️ Current State (PoC)
@@ -165,4 +170,4 @@ This project is currently in PoC stage. While functional, it's under active deve
 
 ## 📜 License
 
-Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
```

docs/twists.md

Lines changed: 66 additions & 0 deletions
# Twists (Permutation Symmetries) in twisteRL
Twists are twisteRL's way to describe the exact permutation symmetries that exist inside an environment. Instead of training a policy to rediscover that these symmetries exist (for example, that swapping qubits across a symmetric coupling map produces an equivalent observation/action space), the environment hands the policy explicit permutations that it can use for data augmentation or symmetry-aware heads. Repeatedly applying and undoing these permutations also reduces the chance of deadlocks and provides a lightweight form of regularization, because the agent sees equivalent states under many orderings.

## Where Twists Are Used
- Every environment implements the `twisterl::rl::env::Env` trait. The trait includes a `twists` method that returns `(Vec<Vec<usize>>, Vec<Vec<usize>>)`, representing valid permutations on the flattened observation array and matching permutations on the discrete action space (`rust/src/rl/env.rs:33`).
- When an environment is instantiated from Python via `prepare_algorithm`, twisteRL immediately calls `env.twists()` and forwards the returned permutations to the policy constructor (`src/twisterl/utils.py:120`). The policy can then symmetrize logits, average values, or augment rollouts without extra environment queries.

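Once the policy holds the permutation lists, one thing it can do with them is average its outputs over the whole orbit of a state. The sketch below is a hypothetical helper, not twisteRL API, and the map-old-index-to-new-index convention it uses is an assumption:

```python
def symmetrized_logits(raw_logits_fn, obs, obs_perms, act_perms):
    """Average action logits over all twists.

    Assumed convention: a permutation maps old index -> new index,
    i.e. twisted[perm[k]] = original[k]; act_perms[i] then relates
    canonical actions to actions in the twisted frame.
    """
    num_actions = len(act_perms[0])
    acc = [0.0] * num_actions
    for obs_perm, act_perm in zip(obs_perms, act_perms):
        twisted = [0] * len(obs)
        for k, j in enumerate(obs_perm):
            twisted[j] = obs[k]              # apply the observation twist
        logits = raw_logits_fn(twisted)      # logits in the twisted frame
        for k, j in enumerate(act_perm):
            acc[k] += logits[j]              # map each logit back and accumulate
    return [v / len(obs_perms) for v in acc]
```

For an equivariant `raw_logits_fn`, the averaged result equals the un-symmetrized one; for an imperfectly trained network, averaging enforces the symmetry at inference time.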
## Data Contract
1. **Observation permutations (`obs_perms`)** are expressed in the same flattened index space produced by the environment's `observe()` method. Each permutation covers every index exactly once.
2. **Action permutations (`act_perms`)** must use the same ordering as `obs_perms`: twisteRL assumes `act_perms[i]` describes how to remap actions when `obs_perms[i]` is applied.
3. The two lists must have the same length (`len(obs_perms) == len(act_perms)`), and the first permutation should usually be the identity so policies have a canonical ordering to fall back on.

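As an illustration of the contract, consider a toy 2×2 grid with a left-right flip. Everything here (the env, one action per cell, and the map-old-index-to-new-index convention) is an assumption made for the example, not twisteRL API:

```python
# Flattened indices:  0 1        after the flip:  1 0
#                     2 3                         3 2
obs_perms = [
    [0, 1, 2, 3],  # identity first, as recommended above
    [1, 0, 3, 2],  # horizontal flip (maps old index -> new index)
]
act_perms = [
    [0, 1, 2, 3],  # act_perms[i] pairs with obs_perms[i]
    [1, 0, 3, 2],  # one action per cell in this toy example
]

def apply_twist(obs, action, i):
    """Return the equivalent (obs, action) pair under twist i."""
    twisted = [0] * len(obs)
    for k, j in enumerate(obs_perms[i]):
        twisted[j] = obs[k]
    return twisted, act_perms[i][action]
```

Because the flip is its own inverse, applying `apply_twist` twice returns the original pair, which is an easy property to unit-test.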
## Implementing Twists in Rust Environments
1. **Compute permutations once**, when the environment is constructed, and store the resulting vectors on the struct so you can reuse them without recomputing at each step.
2. **Return cached permutations** from the `twists` method by cloning or otherwise referencing the stored vectors. This keeps the call cheap even when policies request twists frequently.
3. **Gate toggles through config.** Consider exposing a `use_perms` or `add_perms` flag so users can disable symmetries when they want to benchmark raw performance or compare against non-symmetric runs.

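The compute-once/cache pattern is the same in either language; here it is sketched in Python with a hypothetical environment (the Rust version stores the vectors on the struct in the constructor the same way):

```python
class FlipGridEnv:
    """Hypothetical environment that computes its twists once, up front."""

    def __init__(self, use_perms: bool = True):
        identity = list(range(4))
        flip = [1, 0, 3, 2]
        # Computed once at construction; `twists()` never recomputes.
        if use_perms:
            self._obs_perms = [identity, flip]
            self._act_perms = [identity, flip]
        else:
            # Config gate: identity-only, for benchmarking non-symmetric runs.
            self._obs_perms = [identity]
            self._act_perms = [identity]

    def twists(self):
        # Cheap: just hand back the cached lists.
        return self._obs_perms, self._act_perms
```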
### Tips for new envs
- If your observation is multi-dimensional, decide on a consistent flattening order and reuse it in `observe()`, `obs_shape()`, and the permutation computation.
- Keep the permutation list short: only add a symmetry when it actually preserves the transition dynamics; incorrect permutations can break training stability.
- Store permutations on the struct instead of recomputing them on each `twists()` call, to avoid extra allocations during training.

## Implementing Twists in Python Environments
Python environments exposed through `PyEnv` can mirror the same pattern:

1. **Detect graph/device symmetries** using domain-specific tooling, capturing any permutation that leaves the transition structure unchanged.
2. **Sample a permutation for every observation** if you want trajectories to naturally explore each orbit; this mimics the way many structured environments randomize qubit or tile order.
3. **Expose action permutations** through the PyO3 wrapper so the policy receives matching permutations. When porting a Python env to Rust, copy the action/observation permutation lists into the Rust struct and return them from `twists()`.

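Point 2 above (sampling a permutation per observation) can be sketched as follows; the helper name and the map-old-index-to-new-index convention are assumptions for the example:

```python
import random

def observe_with_random_twist(obs, obs_perms):
    """Return the observation under a uniformly sampled twist, plus its index."""
    i = random.randrange(len(obs_perms))
    twisted = [0] * len(obs)
    for k, j in enumerate(obs_perms[i]):
        twisted[j] = obs[k]
    return twisted, i
```

Whatever twist is drawn, the result is a rearrangement of the original observation, so downstream code only needs the index `i` to undo it.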
## Verifying Your Twists
1. Call `env.twists()` from Python and check that each permutation is a rearrangement of `range(len(observe()))` and `range(num_actions())`.
2. Run a short training job with and without permutations enabled. If the permutations are correct you should see either faster convergence or identical performance; a regression usually means the action and observation permutations are misaligned.
3. For debugging, temporarily limit the permutation list to `[identity]` and re-enable the additional symmetries one at a time.

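Check 1 can be automated with a few assertions; this is a hypothetical helper that just mirrors the checks described above:

```python
def check_twists(obs_perms, act_perms, obs_len, num_actions):
    """Validate the twists data contract for a batch of permutations."""
    assert len(obs_perms) == len(act_perms), "perm lists must have equal length"
    for obs_perm, act_perm in zip(obs_perms, act_perms):
        # Each permutation must be a rearrangement covering every index once.
        assert sorted(obs_perm) == list(range(obs_len)), f"bad obs perm: {obs_perm}"
        assert sorted(act_perm) == list(range(num_actions)), f"bad act perm: {act_perm}"
```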
By explicitly documenting and exposing twists, twisteRL policies gain symmetry awareness for free, leading to higher data efficiency on structured problems such as puzzle solvers and quantum circuit optimization.

rust/src/collector/az.rs

Lines changed: 3 additions & 2 deletions
```diff
@@ -61,7 +61,6 @@ impl AZCollector {
         let mut probs: Vec<Vec<f32>> = vec![];
         let mut vals: Vec<f32> = vec![];
         let mut total_vals: Vec<f32> = vec![];
-
         let mut total_val = 0.0;
 
         // Loop until a final state
@@ -93,9 +92,12 @@ impl AZCollector {
         // Post process rewards
         let remaining_vals: Vec<f32> = total_vals.iter().map(|&v| total_val - v).collect();
 
+        let perms: Vec<Option<usize>> = vec![None; obs.len()];
+
         let mut data = CollectedData::new(
             obs,
             probs,
+            perms,
             vec![],
             vec![],
             vec![],
@@ -182,4 +184,3 @@ mod tests {
         assert!(data.additional_data.contains_key("remaining_values"));
     }
 }
-
```

rust/src/collector/collector.rs

Lines changed: 7 additions & 0 deletions
```diff
@@ -24,6 +24,8 @@ pub struct CollectedData {
     pub obs: Vec<Vec<usize>>,
     /// Logits (action probabilities) at each timestep
     pub logits: Vec<Vec<f32>>,
+    /// Optional permutation index used at each timestep (-1 if none)
+    pub perms: Vec<Option<usize>>,
     /// Value estimates at each timestep
     pub values: Vec<f32>,
     /// Rewards received at each timestep
@@ -48,13 +50,15 @@ impl CollectedData {
     pub fn new(
         obs: Vec<Vec<usize>>,
         logits: Vec<Vec<f32>>,
+        perms: Vec<Option<usize>>,
         values: Vec<f32>,
         rewards: Vec<f32>,
         actions: Vec<usize>,
     ) -> Self {
         CollectedData {
             obs,
             logits,
+            perms,
             values,
             rewards,
             actions,
@@ -67,6 +71,7 @@ impl CollectedData {
         // Append observations and logits (2D vectors)
         self.obs.extend(other.obs.iter().cloned());
         self.logits.extend(other.logits.iter().cloned());
+        self.perms.extend(other.perms.iter().cloned());
 
         // Append 1D vectors
         self.values.extend(&other.values);
@@ -98,6 +103,7 @@ mod tests {
         let d1 = CollectedData::new(
             vec![vec![0]],
             vec![vec![0.1]],
+            vec![Some(0)],
             vec![0.2],
             vec![0.3],
             vec![1],
@@ -106,6 +112,7 @@ mod tests {
         let d2 = CollectedData::new(
             vec![vec![1]],
             vec![vec![0.4]],
+            vec![None],
             vec![0.5],
             vec![0.6],
             vec![0],
```

rust/src/collector/ppo.rs

Lines changed: 7 additions & 5 deletions
```diff
@@ -42,13 +42,13 @@ impl PPOCollector {
         &self,
         env: &dyn Env,
         policy: &Policy,
-    ) -> (Vec<usize>, Vec<f32>, usize, f32, f32) {
+    ) -> (Vec<usize>, Vec<f32>, usize, f32, f32, Option<usize>) {
         let obs = env.observe(); // Vec<f32> or whatever your Env returns
         let masks = env.masks();
         let reward = env.reward();
-        let (logits, value) = policy.forward(obs.clone(), masks);
+        let (logits, value, perm_idx) = policy.forward_with_perm(obs.clone(), masks);
         let action = sample_from_logits(&logits);
-        (obs, logits, action, value, reward)
+        (obs, logits, action, value, reward, perm_idx)
     }
 
     fn single_collect(
@@ -64,14 +64,16 @@ impl PPOCollector {
         let mut vals = Vec::new();
         let mut rews = Vec::new();
         let mut acts = Vec::new();
+        let mut perms = Vec::new();
 
         loop {
-            let (obs, log_prob, act, val, rew) = self.get_step_data(&*env, policy);
+            let (obs, log_prob, act, val, rew, perm_idx) = self.get_step_data(&*env, policy);
             obss.push(obs);
             log_probs.push(log_prob);
             vals.push(val);
             rews.push(rew);
             acts.push(act);
+            perms.push(perm_idx);
 
             if env.is_final() { break; }
             env.step(act);
@@ -92,6 +94,7 @@ impl PPOCollector {
         let mut data = CollectedData::new(
             obss,
             log_probs,
+            perms,
             vals,
             rews,
             acts,
@@ -177,4 +180,3 @@ mod tests {
         assert!(data.additional_data.contains_key("rets"));
     }
 }
-
```
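The new collection loop can be mirrored in Python with duck-typed `env`/`policy` stand-ins. This is an illustration of the control flow (names follow the Rust above; greedy action choice stands in for sampling), not the twisteRL Python API:

```python
def single_collect(env, policy):
    """One episode; stores the permutation index used at each timestep."""
    obss, log_probs, vals, rews, acts, perms = [], [], [], [], [], []
    while True:
        obs, masks = env.observe(), env.masks()
        logits, value, perm_idx = policy.forward_with_perm(obs, masks)
        act = max(range(len(logits)), key=logits.__getitem__)  # greedy stand-in
        obss.append(obs)
        log_probs.append(logits)
        vals.append(value)
        rews.append(env.reward())
        acts.append(act)
        perms.append(perm_idx)
        if env.is_final():
            break
        env.step(act)
    return obss, log_probs, vals, rews, acts, perms
```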

rust/src/nn/policy.rs

Lines changed: 14 additions & 9 deletions
```diff
@@ -32,31 +32,36 @@ impl Policy {
     }
 
     pub fn predict(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32) {
-        // Forward of the action net
-        let (action_logits, value) = self._raw_predict(obs, self.get_perm_id());
+        let (exp_masked_probs, value, _) = self.predict_with_perm(obs, masks);
+        (exp_masked_probs, value)
+    }
+
+    pub fn predict_with_perm(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32, Option<usize>) {
+        let (action_logits, value, perm_idx) = self.forward_with_perm(obs, masks.clone());
 
         // Apply masks to the actions
         let mut exp_masked_probs: Vec<f32> = action_logits.iter().zip(masks.iter()).map(|(&a, &m)| if m {a.exp()} else {0.0}).collect();
 
-        // TODO: apply noise to the actions
-
         // Normalize actions
         let action_probs_sum: f32 = exp_masked_probs.iter().sum();
         exp_masked_probs = exp_masked_probs.iter().map(|&v| v / (action_probs_sum + 0.000001)).collect();
-        (exp_masked_probs, value)
+        (exp_masked_probs, value, perm_idx)
     }
 
-
     pub fn forward(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32) {
-        // Similar to predict but outputs unnormalized logits instead of probabilities
+        let (masked_logits, value, _) = self.forward_with_perm(obs, masks);
+        (masked_logits, value)
+    }
 
+    pub fn forward_with_perm(&self, obs: Vec<usize>, masks: Vec<bool>) -> (Vec<f32>, f32, Option<usize>) {
         // Forward of the action net
-        let (action_logits, value) = self._raw_predict(obs, self.get_perm_id());
+        let perm_idx = self.get_perm_id();
+        let (action_logits, value) = self._raw_predict(obs, perm_idx);
 
         // Apply masks to the actions
         let masked_logits: Vec<f32> = action_logits.iter().zip(masks.iter()).map(|(&a, &m)| if m {a} else {-1e10}).collect();
 
-        (masked_logits, value)
+        (masked_logits, value, perm_idx)
     }
 
     fn get_perm_id(&self) -> Option<usize> {
```
rust/src/python_interface/collector.rs

Lines changed: 16 additions & 1 deletion
```diff
@@ -37,9 +37,11 @@ impl PyCollectedData {
         values: Vec<f32>,
         rewards: Vec<f32>,
         actions: Vec<usize>,
+        perms: Option<Vec<Option<usize>>>,
     ) -> Self {
+        let perms = perms.unwrap_or_else(|| vec![None; obs.len()]);
         PyCollectedData {
-            inner: CollectedData::new(obs, logits, values, rewards, actions),
+            inner: CollectedData::new(obs, logits, perms, values, rewards, actions),
         }
     }
 
@@ -69,6 +71,19 @@ impl PyCollectedData {
     fn set_logits(&mut self, logits: Vec<Vec<f32>>) {
         self.inner.logits = logits;
     }
+
+    #[getter]
+    fn get_perms(&self) -> Vec<i64> {
+        self.inner.perms.iter().map(|opt| opt.map(|v| v as i64).unwrap_or(-1)).collect()
+    }
+
+    #[setter]
+    fn set_perms(&mut self, perms: Vec<i64>) {
+        self.inner.perms = perms
+            .into_iter()
+            .map(|v| if v < 0 { None } else { Some(v as usize) })
+            .collect();
+    }
 
     #[getter]
     fn get_values(&self) -> Vec<f32> {
```
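On the Python side, `perms` round-trips through a `-1` sentinel: the getter maps `None` to `-1` and the setter maps any negative value back to `None`. The conversion, sketched in Python (helper names are illustrative, not part of the bindings):

```python
def perms_to_py(perms):
    """Mirror of the Rust getter: Option<usize> -> i64, None becomes -1."""
    return [-1 if p is None else int(p) for p in perms]

def perms_from_py(values):
    """Mirror of the Rust setter: negative values become None."""
    return [None if v < 0 else v for v in values]
```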
