Skip to content

Commit 9fcae36

Browse files
committed
Fixed idx bug for val and test dataloader, changed clairvoyant agent to work with a task
1 parent 3d2a1a3 commit 9fcae36

File tree

7 files changed

+51
-25
lines changed

7 files changed

+51
-25
lines changed

ddopai/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,8 @@
11701170
'ddopai/envs/pricing/base.py'),
11711171
'ddopai.envs.pricing.base.BasePricingEnv.get_observation': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.get_observation',
11721172
'ddopai/envs/pricing/base.py'),
1173+
'ddopai.envs.pricing.base.BasePricingEnv.get_task': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.get_task',
1174+
'ddopai/envs/pricing/base.py'),
11731175
'ddopai.envs.pricing.base.BasePricingEnv.reset': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.reset',
11741176
'ddopai/envs/pricing/base.py'),
11751177
'ddopai.envs.pricing.base.BasePricingEnv.reset_index': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.reset_index',

ddopai/agents/dynamic_pricing/clairvoyant.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@ def __init__(self,
2929
obsprocessors: Optional[List[object]] = None,
3030
actionprocessors: Optional[List[object]] = None,
3131
agent_name: str | None = None,
32-
alpha: np.ndarray | None = None,
33-
beta: np.ndarray | None = None,
32+
task: dict = None,
3433
price_function = None,
3534
g = None,
3635
):
36+
37+
alpha = np.array(task["alpha"])
38+
beta = np.array(task["beta"])
3739
assert type(alpha) == type(beta), "alpha and beta must be of the same type"
3840
if type(alpha) == None:
3941
alpha = np.zeros(environment_info.observation_space['features'].shape[0])
4042
beta = np.zeros(environment_info.observation_space['features'].shape[0])
4143
self.environment_info = environment_info
44+
self.task = task
4245
self.alpha = alpha
4346
self.beta = beta
4447
self.actionprocessors = actionprocessors
@@ -60,6 +63,8 @@ def draw_action(self, observation: np.ndarray):
6063

6164
def update_env(self, env):
6265
self.environment_info = env.mdp_info
66+
self.task = env.get_task()
67+
6368
self.alpha = env.alpha
6469
self.beta = env.beta
6570
"""TODO add change in price function"""
@@ -83,13 +88,12 @@ def __init__(self,
8388
obsprocessors: Optional[List[object]] = [],
8489
actionprocessors: Optional[List[object]] = [],
8590
agent_name: str | None = None,
86-
alpha: np.ndarray | None = None,
87-
beta: np.ndarray | None = None,
91+
task: dict = None,
8892
price_function = None,
8993
g = None,
9094
):
9195

92-
policy = ClairvoyantPolicy(environment_info=environment_info, obsprocessors=obsprocessors, actionprocessors=actionprocessors, alpha=alpha, beta=beta, price_function=price_function, g=g)
96+
policy = ClairvoyantPolicy(environment_info=environment_info, obsprocessors=obsprocessors, actionprocessors=actionprocessors, task=task, price_function=price_function, g=g)
9397
self.agent_name = agent_name
9498
super().__init__(environment_info, policy)
9599

@@ -115,17 +119,15 @@ def __init__(self,
115119
obsprocessors: Optional[List[object]] =[],
116120
actionprocessors: Optional[List[object]] = [],
117121
agent_name: str | None = None,
118-
alpha: np.ndarray | None = None,
119-
beta: np.ndarray | None = None,
122+
task: dict = None,
120123
price_function = None,
121124
g = None,
122125
):
123126
self.agent = ClairvoyantCoreAgent(environment_info = environment_info,
124127
obsprocessors = obsprocessors,
125128
actionprocessors = actionprocessors,
126129
agent_name = agent_name,
127-
alpha = alpha,
128-
beta = beta,
130+
task=task,
129131
price_function = price_function,
130132
g = g)
131133
super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)

ddopai/dataloaders/online.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self,
4040
normalize_features: dict = None,
4141
):
4242
self.X = X
43+
self.Y = epsilon
4344
self.alpha = alpha
4445
self.beta = beta
4546
self.epsilon = epsilon
@@ -196,14 +197,14 @@ def probit(X, action):
196197
return np.maximum(demand, 0)
197198
return probit
198199

199-
def __getitem__(self, index: int):
200+
def __getitem__(self, idx: int):
200201

201202
"""
202203
get item by index, depending on the dataset type (train, val, test)
203204
"""
204205

205206
if self.dataset_type == "train":
206-
if index > self.train_index_end:
207+
if idx > self.train_index_end:
207208
raise IndexError('Index out of bounds')
208209

209210
elif self.dataset_type == "val":
@@ -225,7 +226,7 @@ def __getitem__(self, index: int):
225226
else:
226227
raise ValueError('dataset_type not set')
227228

228-
return self.X[index], self._get_Y(index)
229+
return self.X[idx], self._get_Y(idx)
229230

230231
def __len__(self):
231232
return len(self.X)

ddopai/envs/pricing/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,12 @@ def test(self, update_mdp_info=True):
245245
self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=self.mdp_info.horizon)
246246

247247
self.reset()
248+
249+
def get_task(self):
250+
"""
251+
Return the current task. In the online learning case it will return only the state;
252+
this function should be overridden by subclasses.
253+
254+
"""
255+
256+
return self.task.copy()

nbs/10_dataloaders/13_online_dataloaders.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
" normalize_features: dict = None,\n",
7777
" ):\n",
7878
" self.X = X\n",
79+
" self.Y = epsilon\n",
7980
" self.alpha = alpha\n",
8081
" self.beta = beta\n",
8182
" self.epsilon = epsilon\n",
@@ -232,14 +233,14 @@
232233
" return np.maximum(demand, 0)\n",
233234
" return probit\n",
234235
"\n",
235-
" def __getitem__(self, index: int):\n",
236+
" def __getitem__(self, idx: int):\n",
236237
" \n",
237238
" \"\"\"\n",
238239
" get item by index, depending on the dataset type (train, val, test)\n",
239240
" \"\"\"\n",
240241
" \n",
241242
" if self.dataset_type == \"train\":\n",
242-
" if index > self.train_index_end:\n",
243+
" if idx > self.train_index_end:\n",
243244
" raise IndexError('Index out of bounds')\n",
244245
" \n",
245246
" elif self.dataset_type == \"val\":\n",
@@ -261,7 +262,7 @@
261262
" else:\n",
262263
" raise ValueError('dataset_type not set')\n",
263264
"\n",
264-
" return self.X[index], self._get_Y(index)\n",
265+
" return self.X[idx], self._get_Y(idx)\n",
265266
" \n",
266267
" def __len__(self):\n",
267268
" return len(self.X)\n",

nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,16 @@
280280
" if update_mdp_info:\n",
281281
" self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=self.mdp_info.horizon)\n",
282282
"\n",
283-
" self.reset()"
283+
" self.reset()\n",
284+
" \n",
285+
" def get_task(self):\n",
286+
" \"\"\"\n",
287+
" Return the current task. This function is for the online learning case it will return only the state,\n",
288+
" this function should be overwritten.\n",
289+
"\n",
290+
" \"\"\"\n",
291+
"\n",
292+
" return self.task.copy()"
284293
]
285294
}
286295
],

nbs/30_agents/42_DP_agents/10_clairvoyant_agent.ipynb

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,19 @@
6666
" obsprocessors: Optional[List[object]] = None,\n",
6767
" actionprocessors: Optional[List[object]] = None,\n",
6868
" agent_name: str | None = None,\n",
69-
" alpha: np.ndarray | None = None,\n",
70-
" beta: np.ndarray | None = None,\n",
69+
" task: dict = None,\n",
7170
" price_function = None,\n",
7271
" g = None,\n",
7372
" ):\n",
73+
" \n",
74+
" alpha = np.array(task[\"alpha\"])\n",
75+
" beta = np.array(task[\"beta\"])\n",
7476
" assert type(alpha) == type(beta), \"alpha and beta must be of the same type\"\n",
7577
" if type(alpha) == None:\n",
7678
" alpha = np.zeros(environment_info.observation_space['features'].shape[0]) \n",
7779
" beta = np.zeros(environment_info.observation_space['features'].shape[0])\n",
7880
" self.environment_info = environment_info\n",
81+
" self.task = task\n",
7982
" self.alpha = alpha\n",
8083
" self.beta = beta\n",
8184
" self.actionprocessors = actionprocessors\n",
@@ -97,6 +100,8 @@
97100
" \n",
98101
" def update_env(self, env):\n",
99102
" self.environment_info = env.mdp_info\n",
103+
" self.task = env.get_task()\n",
104+
" \n",
100105
" self.alpha = env.alpha\n",
101106
" self.beta = env.beta\n",
102107
" \"\"\"TODO add change in price function\"\"\"\n",
@@ -127,13 +132,12 @@
127132
" obsprocessors: Optional[List[object]] = [],\n",
128133
" actionprocessors: Optional[List[object]] = [],\n",
129134
" agent_name: str | None = None,\n",
130-
" alpha: np.ndarray | None = None,\n",
131-
" beta: np.ndarray | None = None,\n",
135+
" task: dict = None,\n",
132136
" price_function = None,\n",
133137
" g = None,\n",
134138
" ):\n",
135139
" \n",
136-
" policy = ClairvoyantPolicy(environment_info=environment_info, obsprocessors=obsprocessors, actionprocessors=actionprocessors, alpha=alpha, beta=beta, price_function=price_function, g=g)\n",
140+
" policy = ClairvoyantPolicy(environment_info=environment_info, obsprocessors=obsprocessors, actionprocessors=actionprocessors, task=task, price_function=price_function, g=g)\n",
137141
" self.agent_name = agent_name\n",
138142
" super().__init__(environment_info, policy)\n",
139143
" \n",
@@ -166,17 +170,15 @@
166170
" obsprocessors: Optional[List[object]] =[],\n",
167171
" actionprocessors: Optional[List[object]] = [],\n",
168172
" agent_name: str | None = None,\n",
169-
" alpha: np.ndarray | None = None,\n",
170-
" beta: np.ndarray | None = None,\n",
173+
" task: dict = None,\n",
171174
" price_function = None,\n",
172175
" g = None,\n",
173176
" ):\n",
174177
" self.agent = ClairvoyantCoreAgent(environment_info = environment_info,\n",
175178
" obsprocessors = obsprocessors, \n",
176179
" actionprocessors = actionprocessors, \n",
177180
" agent_name = agent_name, \n",
178-
" alpha = alpha, \n",
179-
" beta = beta, \n",
181+
" task=task,\n",
180182
" price_function = price_function, \n",
181183
" g = g)\n",
182184
" super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)\n",

0 commit comments

Comments
 (0)