Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added 2x2 phi traj CP.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added phi_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 47 additions & 0 deletions pufferlib/config/ocean/rubiks.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[base]
package = ocean
env_name = puffer_rubiks
policy_name = Policy
rnn_name = Recurrent

[env]
num_envs = 4096

[train]

adam_beta1= 0.4999999999999999
adam_beta2= 0.999497290393837
adam_eps= 1.092659057939667e-08
anneal_lr= True
batch_size= auto
bptt_horizon= 64
checkpoint_interval= 200
clip_coef= 0.12449250364976959
compile= False
compile_fullgraph=True
compile_mode= max-autotune-no-cudagraphs
cpu_offload= False
data_dir= experiments
device= cpu
ent_coef= 0.20000000000000004
gae_lambda= 0.8797374705059637
gamma= 0.9969927707900579
learning_rate= 0.1
max_grad_norm= 1.6074187450788373
max_minibatch_size= 32768
minibatch_size= 65536
name= pufferai
optimizer= muon
precision= float32
prio_alpha= 0.956020391561609
prio_beta0= 0.9129611672660245
project= ablations
seed= 42
torch_deterministic= True
total_timesteps= 6.916886699061722e+07
update_epochs= 1
use_rnn= True
vf_clip_coef= 0.1
vf_coef= 0.982370686402245
vtrace_c_clip= 0
vtrace_rho_clip= 0.28767080539864404
1 change: 1 addition & 0 deletions pufferlib/ocean/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def make_multiagent(buf=None, **kwargs):
'checkers': 'Checkers',
'asteroids': 'Asteroids',
'whisker_racer': 'WhiskerRacer',
'rubiks': 'Cube',
'onestateworld': 'World',
'onlyfish': 'OnlyFish',
'chain_mdp': 'Chain',
Expand Down
22 changes: 22 additions & 0 deletions pufferlib/ocean/rubiks/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "rubiks.h"

#define Env Cube
#include "../env_binding.h"

static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
env->N = (int) unpack(kwargs, "N");
env->shuffles = (int) unpack(kwargs, "shuffles");
env->size = (int) unpack(kwargs, "size");
env->max_episode_steps = (int) unpack(kwargs, "max_episode_steps");
env->anim_time = (float) unpack(kwargs, "anim_time");
init(env);
return 0;
}

static int my_log(PyObject* dict, Log* log) {
assign_to_dict(dict, "perf", log->perf);
assign_to_dict(dict, "score", log->score);
assign_to_dict(dict, "episode_return", log->episode_return);
assign_to_dict(dict, "episode_length", log->episode_length);
return 0;
}
80 changes: 80 additions & 0 deletions pufferlib/ocean/rubiks/rubiks.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "rubiks.h"
#include <unistd.h>
#include <string.h>
#include "puffernet.h"

//Specific functions for user mode only

//To convert highlights to actions
static inline int axis_layer_to_face(int axis, int layer, int N) {
int outer = (layer == N-1); // 1 if the positive side slab
switch (axis) {
case 0: return outer ? R : L; // +X is R, -X is L
case 1: return outer ? U : D; // +Y is U, -Y is D
case 2: return outer ? F : B; // +Z is F, -Z is B
default: return -1;
}
}

static inline int face_dir_to_action(int face, int cw) {
// decode_action: even -> +1 turn, odd -> -1 turn
// treat cw as +1
return face * 2 + (cw ? 0 : 1);
}

// Directly from highlight to action
static inline int highlight_to_action(const Cube *env, int cw) {
int face = axis_layer_to_face(env->highlight_axis, env->highlight_layer, env->N);
return face < 0 ? -1 : face_dir_to_action(face, cw);
}

int main() {
int N = 3;
int num_obs = 6*N*N*6;


Cube env = {
.N = N,
.shuffles = 0,
.size = num_obs
};
init(&env);



env.observations = calloc(num_obs, sizeof(float));
env.actions = calloc(12, sizeof(int));
env.rewards = calloc(1, sizeof(float));
env.terminals = calloc(1, sizeof(unsigned char));
env.max_episode_steps = 1000;


c_reset(&env);
c_render(&env);

env.user_mode = 1;
while (!WindowShouldClose()) {
c_render(&env);

if (IsKeyPressed(KEY_ENTER)) { // CW
int a = highlight_to_action(&env, 1);
env.actions[0] = a;
c_step(&env);
}
if (IsKeyPressed(KEY_BACKSPACE)) { // CCW
int a = highlight_to_action(&env, 0);
env.actions[0] = a;
c_step(&env);
}
}


free(env.observations);
free(env.actions);
free(env.rewards);
free(env.terminals);
c_close(&env);
printf("Done\n");

}

Loading
Loading