Skip to content

Commit 52e1d7f

Browse files
committed
Ported the main metalearning infrastructure from the hyper repo
1 parent 90f0015 commit 52e1d7f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+15949
-82
lines changed

ddopai/_modidx.py

Lines changed: 579 additions & 27 deletions
Large diffs are not rendered by default.

ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(self,
5858

5959
def draw_action(self, observation: np.ndarray):
6060
X = observation['features']
61-
B_t = observation['Inventory']
61+
B_t = observation['inventory']
6262
price = self.price_function(X, self.alpha, self.beta)
6363
lagrangian = self.lagrangian(B_t)
6464
price = price + lagrangian
@@ -74,6 +74,7 @@ def lagrangian(self, B_t):
7474
avg_remaining_B = (2 * B_t) / (self.T - self.t +1)
7575
lagrangian = (avg_remaining_B - np.dot(self.alpha, self.E_X)) / np.dot(self.beta, self.E_X)
7676
return lagrangian
77+
7778
def update_task(self, env):
7879
self.environment_info = env.mdp_info
7980
self.task = env.get_task()
@@ -85,6 +86,7 @@ def update_task(self, env):
8586
else:
8687
self.E_X = np.full(self.environment_info.observation_space['features'].shape[0], 1 / (2 * np.sqrt(self.environment_info.observation_space['features'].shape[0])))
8788
self.T = self.task["horizon"]
89+
self.t = 0
8890

8991
"""TODO add change in price function"""
9092
def fit(self, X, Y, action):

ddopai/agents/ml_utils.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/30_agents/40_ml_utils.ipynb.
44

55
# %% auto 0
6-
__all__ = ['LRSchedulerPerStep']
6+
__all__ = ['LRSchedulerPerStep', 'init_gru', 'init_module', 'init_mlp']
77

88
# %% ../../nbs/30_agents/40_ml_utils.ipynb 3
99
from typing import List, Tuple, Literal
1010
import torch
1111

12+
from typing import Callable, List
13+
14+
import numpy as np
15+
from torch import nn as nn
16+
from torch.nn.utils import weight_norm
17+
1218
# %% ../../nbs/30_agents/40_ml_utils.ipynb 4
1319
class LRSchedulerPerStep():
1420
"""
@@ -39,3 +45,83 @@ def step(self):
3945

4046
for param_group in self.optimizer.param_groups:
4147
param_group['lr'] = lr
48+
49+
# %% ../../nbs/30_agents/40_ml_utils.ipynb 5
50+
def init_gru(input_size: int, recurrent_state_size: int) -> nn.Module:
    """
    Build a single-layer GRU with zeroed biases and orthogonal weight matrices.

    Args:
        input_size (int): Number of input features fed to the GRU.
        recurrent_state_size (int): Dimensionality of the GRU hidden state.

    Returns:
        nn.Module: The freshly initialized ``nn.GRU`` module.
    """
    gru = nn.GRU(input_size, recurrent_state_size)

    # Orthogonal weights and zero biases — a common stabilizing init for RNNs.
    for name, param in gru.named_parameters():
        if "weight" in name:
            nn.init.orthogonal_(param)
        elif "bias" in name:
            nn.init.constant_(param, 0)

    return gru
70+
71+
def init_module(
    module: nn.Module, weight_init: Callable, bias_init: Callable, gain: float = 1.0
) -> nn.Module:
    """
    Initialize a module's weight and bias in place, then apply weight normalization.

    Args:
        module (nn.Module): Module to initialize; must expose a ``weight``
            parameter (e.g. ``nn.Linear``).
        weight_init (Callable): Function initializing weights, called as
            ``weight_init(tensor, gain=gain)``.
        bias_init (Callable): Function initializing biases, called as
            ``bias_init(tensor)``. Skipped when the module has no bias.
        gain (float): Gain amount forwarded to ``weight_init``.

    Returns:
        nn.Module: The same module, initialized and weight-normalized.
    """
    weight_init(module.weight.data, gain=gain)
    # Guard against modules constructed with bias=False (bias is None there).
    if getattr(module, "bias", None) is not None:
        bias_init(module.bias.data)
    # NOTE(review): weight_norm mutates the module in place; it is deprecated
    # in favor of torch.nn.utils.parametrizations.weight_norm — flagged, not
    # replaced, to keep checkpoint/state-dict compatibility.
    weight_norm(module)

    return module
91+
92+
93+
def init_mlp(input_size: int, hidden_sizes: List[int]) -> nn.Sequential:
    """
    Build an MLP of ``Linear`` + ``ReLU`` layers with orthogonal weights and
    zero biases.

    Args:
        input_size (int): Number of input features to the first layer.
        hidden_sizes (List[int]): Output sizes of the successive hidden layers.

    Returns:
        nn.Sequential: The initialized MLP (empty when ``hidden_sizes`` is empty).
    """

    def _init_orthogonal(m: nn.Module) -> nn.Module:
        # Orthogonal weights (gain sqrt(2) suits ReLU activations), zero bias.
        return init_module(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)
        )

    feature_sizes = [input_size, *hidden_sizes]

    mlp_modules: List[nn.Module] = []
    for in_features, out_features in zip(feature_sizes[:-1], feature_sizes[1:]):
        # _init_orthogonal already zeroes the bias; no second pass is needed.
        mlp_modules.append(_init_orthogonal(nn.Linear(in_features, out_features)))
        mlp_modules.append(nn.ReLU())

    return nn.Sequential(*mlp_modules)

ddopai/agents/rl/RL2ppo.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""PPO based agent"""
22

3-
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/30_agents/51_RL_agents/10_RL2PPO_agents.ipynb.
3+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/30_agents/51_RL_agents/10_RL2_agents.ipynb.
44

55
# %% auto 0
66
__all__ = ['GaussianTorchPolicyRL2', 'RL2PPO', 'RL2PPOAgent']
77

8-
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2PPO_agents.ipynb 4
8+
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2_agents.ipynb 4
99
import logging
1010

1111
# set logging level to INFO
@@ -44,7 +44,7 @@
4444
from itertools import chain
4545
import time
4646

47-
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2PPO_agents.ipynb 5
47+
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2_agents.ipynb 6
4848
class GaussianTorchPolicyRL2(TorchPolicy):
4949
"""
5050
Torch policy implementing a Gaussian policy with trainable standard
@@ -176,7 +176,7 @@ def parameters(self):
176176
return chain(self._mu.model.network.parameters(), [self._log_sigma])
177177

178178

179-
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2PPO_agents.ipynb 6
179+
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2_agents.ipynb 7
180180
class RL2PPO(Agent):
181181
"""
182182
Proximal Policy Optimization (PPO) Agent supporting sequential data and RL² compatibility.
@@ -479,7 +479,7 @@ def get_tensor(field, dtype=None):
479479
self.standardize_advantages(meta_episodes)
480480
# Convert stored fields to torch tensors.
481481
mb_obs = get_tensor("obs") # shape: (B, T, obs_dim)
482-
mb_acs = get_tensor("acs", "long") # shape: (B, T, action_dim)
482+
mb_acs = get_tensor("acs") # shape: (B, T, action_dim)
483483
mb_rews = get_tensor("rews") # shape: (B, T)
484484
mb_dones = get_tensor("dones") # shape: (B, T)
485485
mb_logpacs = get_tensor("logpacs") # shape: (B, T) or (B, T, 1)
@@ -541,7 +541,7 @@ def _post_load(self):
541541
update_optimizer_parameters(self._optimizer, list(self.policy.parameters()))
542542

543543

544-
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2PPO_agents.ipynb 7
544+
# %% ../../../nbs/30_agents/51_RL_agents/10_RL2_agents.ipynb 8
545545
class RL2PPOAgent(MushroomBaseAgent):
546546
"""
547547
RL² PPO Agent for meta-learning, based on recurrent policy/value networks and MushroomRL core agent.
@@ -682,13 +682,7 @@ def reset_hidden(self, batch_size=1, device='cpu'):
682682
device (str): device where hidden states should be allocated ('cpu' or 'cuda').
683683
"""
684684
self.agent.reset_hidden_state(batch_size=batch_size, device=device)
685-
# Reset actor hidden state
686-
#actor_network = self.agent.policy._mu._impl.model.network
687-
#actor_network.hidden_state = actor_network.model.rnn.model.init_hidden(batch_size=batch_size, device=device)
688-
689-
# Reset critic hidden state
690-
#critic_network = self.agent._V._impl.model.network
691-
#critic_network.hidden_state = critic_network.model.rnn.model.init_hidden(batch_size=batch_size, device=device)
685+
692686

693687
def predict_(self, observation: np.ndarray) -> np.ndarray:
694688
"""

ddopai/envs/pricing/dynamic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ def step_(self,
132132
truncated = self.set_index()
133133

134134
info = dict(
135-
inv=self.inv * self.relative_inv,
135+
inv=(self.inv * self.relative_inv)[0],
136136
demand=demand,
137137
true_demand=true_demand,
138-
action=action.copy(),
138+
action=action[0],
139139
reward=reward,
140140
true_reward=true_reward,
141141
alpha=alpha,
@@ -163,11 +163,11 @@ def get_observation(self):
163163
Function to get the observation from the dataloader.
164164
"""
165165
x, reward_functions = self.dataloader[self.index]
166-
current_inv = np.array([self.relative_inv], dtype=np.float32)
166+
current_inv = self.relative_inv
167167

168168
observation = {
169169
"features": x,
170-
"inventory": current_inv
170+
"inventory": current_inv * self.inv
171171
}
172172
return observation, reward_functions
173173

@@ -191,7 +191,7 @@ def update_episode_params(self):
191191
"""
192192
Update the parameters of the episode.
193193
"""
194-
inv = np.array(self.task["inv_level"])
194+
inv = np.array([self.task["inv_level"]])
195195
relative_inv = np.ones_like(inv, dtype=np.float32)
196196
if hasattr(self, "inv"):
197197
self.set_param("inv", inv, inv.shape, new=False)

ddopai/experiments/experiment_functions_meta.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/40_experiments/10_experiment_functions_meta.ipynb.
44

55
# %% auto 0
6-
__all__ = ['EarlyStoppingHandler', 'calculate_score', 'log_info', 'log_info_history', 'update_best', 'save_agent', 'test_agent',
7-
'run_test_episode', 'run_experiment']
6+
__all__ = ['EarlyStoppingHandler', 'calculate_score', 'log_info', 'log_info_history', 'log_figure_from_history', 'update_best',
7+
'save_agent', 'test_agent', 'run_test_episode', 'run_experiment']
88

99
# %% ../../nbs/40_experiments/10_experiment_functions_meta.ipynb 3
1010
from abc import ABC, abstractmethod
@@ -14,7 +14,8 @@
1414
import numpy as np
1515
import sys
1616
import wandb
17-
17+
import matplotlib.pyplot as plt
18+
import seaborn as sns
1819
from ..envs.base import BaseEnvironment
1920
from ..agents.base import BaseAgent
2021

@@ -124,13 +125,55 @@ def log_info(R: float,
124125
def log_info_history(info: list,
125126
episode: int,
126127
tracking: Literal["wandb"], # only wandb implemented so far
127-
mode: Literal["train", "val", "test"]
128-
):
128+
mode: Literal["train", "val", "test"],
129+
commit: bool = False):
129130
if tracking == "wandb":
130-
table = wandb.Table(columns=["t", "action", "reward", "true_reward", "alpha", "beta", "episode"])
131+
table = wandb.Table(columns=["t", "action", "reward", "true_reward", "alpha", "beta", "inventory", "episode"])
131132
for t, row in enumerate(info):
132-
table.add_data(t, row["action"], row["reward"], row["true_reward"], row["alpha"], row["beta"], episode)
133-
wandb.log({f"{mode}/info_table": table})
133+
table.add_data(t, row["action"], row["reward"], row["true_reward"], row["alpha"], row["beta"], row["inv"], episode)
134+
wandb.log({f"{mode}/info_table": table}, commit=commit)
135+
136+
def log_figure_from_history(info: list,
                            episode: int,
                            tracking: Literal["wandb"], # only wandb implemented so far
                            mode: Literal["train", "val", "test"],
                            commit: bool = True
                            ):
    """
    Log line plots of reward, action, and inventory over one episode to wandb.

    Args:
        info (list): Per-step dicts with keys "reward", "true_reward",
            "action", and "inv".
        episode (int): Episode number (currently unused; kept for a signature
            parallel to log_info_history).
        tracking (Literal["wandb"]): Tracking backend; only "wandb" is implemented.
        mode (Literal["train", "val", "test"]): Prefix for the wandb log keys.
        commit (bool): Whether the final wandb.log call commits the step.
    """
    if tracking != "wandb":
        return

    steps = list(range(len(info)))

    def _log_lineplot(series: dict, title: str, ylabel: str, key: str, commit_now: bool):
        # Render one seaborn line figure and push it to wandb as an image.
        plt.figure(figsize=(10, 6))
        for label, values in series.items():
            sns.lineplot(x=steps, y=values, label=label)
        plt.title(title)
        plt.xlabel("T")
        plt.ylabel(ylabel)
        plt.legend()
        fig = plt.gcf()
        wandb.log({f"{mode}/{key}": wandb.Image(fig)}, commit=commit_now)
        plt.close(fig)

    _log_lineplot(
        {"Reward": [row["reward"] for row in info],
         "True Reward": [row["true_reward"] for row in info]},
        "Reward and True Reward over time", "Reward",
        "reward_over_time_image", False,
    )
    _log_lineplot(
        {"Action": [row["action"] for row in info]},
        "Action over time", "Action",
        "action_over_time_image", False,
    )
    # Only the last log may commit so all three images land on the same step.
    _log_lineplot(
        {"Inventory": [row["inv"] for row in info]},
        "Inventory over time", "Inventory",
        "inventory_over_time_image", commit,
    )
176+
134177

135178
def update_best(R: float, J: float, best_R: float, best_J: float): #
136179

@@ -209,7 +252,8 @@ def test_agent(agent: BaseAgent,
209252
if tracking == "wandb":
210253
mode = env.mode
211254
wandb.log({f"{mode}/Episode":episode,f"{mode}/R": R, f"{mode}/J": J}, commit=False)
212-
log_info_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode)
255+
log_info_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=False)
256+
log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=True)
213257
if return_dataset:
214258
return np.mean(list_R), np.mean(list_J), dataset
215259
else:
@@ -362,7 +406,8 @@ def run_experiment( agent: BaseAgent,
362406
R_list.append(R)
363407
J_list.append(J)
364408
wandb.log({f"test/R": R, f"test/J": J}, commit=False)
365-
log_info_history(env.get_info(), episode, tracking, "test")
409+
log_info_history(env.get_info(), episode, tracking, "test", commit=False)
410+
log_figure_from_history(env.get_info(), episode, tracking, "test", commit=True)
366411
if ((episode+1) % print_freq) == 0:
367412
logging.info(f"Episode {episode+1}: R={R}, J={J}")
368413

ddopai/meta_learning/__init__.py

Whitespace-only changes.

ddopai/meta_learning/algorithms/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)