Skip to content

Commit 90f0015

Browse files
committed
Added IDP agent
1 parent 9ed6ae0 commit 90f0015

File tree

6 files changed

+410
-2
lines changed

6 files changed

+410
-2
lines changed

ddopai/_modidx.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,34 @@
366366
'ddopai/agents/dynamic_pricing/greedy.py'),
367367
'ddopai.agents.dynamic_pricing.greedy.GreedyPolicy.update_task': ( '30_agents/42_DP_agents/greedy_agent.html#greedypolicy.update_task',
368368
'ddopai/agents/dynamic_pricing/greedy.py')},
369+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP': { 'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpagent',
370+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
371+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent.__init__': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpagent.__init__',
372+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
373+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent.update_task': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpagent.update_task',
374+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
375+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPCoreAgent': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpcoreagent',
376+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
377+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPCoreAgent.__init__': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpcoreagent.__init__',
378+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
379+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPCoreAgent.fit': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpcoreagent.fit',
380+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
381+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPCoreAgent.update_task': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpcoreagent.update_task',
382+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
383+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy',
384+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
385+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.__init__': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.__init__',
386+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
387+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.draw_action': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.draw_action',
388+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
389+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.fit': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.fit',
390+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
391+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.lagrangian': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.lagrangian',
392+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
393+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.reset': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.reset',
394+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
395+
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPPolicy.update_task': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idppolicy.update_task',
396+
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py')},
369397
'ddopai.agents.dynamic_pricing.mushroom_rl': { 'ddopai.agents.dynamic_pricing.mushroom_rl.PricingMushroomBaseAgent': ( '30_agents/42_DP_agents/mushroom_base_agent.html#pricingmushroombaseagent',
370398
'ddopai/agents/dynamic_pricing/mushroom_rl.py'),
371399
'ddopai.agents.dynamic_pricing.mushroom_rl.PricingMushroomBaseAgent.__init__': ( '30_agents/42_DP_agents/mushroom_base_agent.html#pricingmushroombaseagent.__init__',

ddopai/agents/class_names.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,7 @@
3838
"TS": "ddopai.agents.dynamic_pricing.TS.TSAgent",
3939
"UCB": "ddopai.agents.dynamic_pricing.UCB.UCBAgent",
4040
"Clairvoyant": "ddopai.agents.dynamic_pricing.clairvoyant.ClairvoyantAgent",
41-
"MTS": "ddopai.agents.dynamic_pricing.MTS.MTSAgent"
41+
"MTS": "ddopai.agents.dynamic_pricing.MTS.MTSAgent",
42+
43+
"IDP": "ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent"
4244
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Agents that knows the underlying task and the optimal action"""
2+
3+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/30_agents/42_DP_agents/421_DP_inventory_agents/10_IDP_agent.ipynb.
4+
5+
# %% auto 0
6+
__all__ = ['IDPPolicy', 'IDPCoreAgent', 'IDPAgent']
7+
8+
# %% ../../../../nbs/30_agents/42_DP_agents/421_DP_inventory_agents/10_IDP_agent.ipynb 3
9+
import logging
10+
11+
from abc import ABC, abstractmethod
12+
from typing import Union, Optional, List
13+
import numpy as np
14+
import joblib
15+
import os
16+
import statsmodels.api as sm
17+
from ..utils import GLMLink
18+
from ....envs.base import BaseEnvironment
19+
from ..mushroom_rl import PricingMushroomBaseAgent
20+
from mushroom_rl.core import Agent
21+
from ....utils import MDPInfo
22+
from ...obsprocessors import FlattenTimeDimNumpy
23+
from ....envs.actionprocessors import ClipAction
24+
25+
# %% ../../../../nbs/30_agents/42_DP_agents/421_DP_inventory_agents/10_IDP_agent.ipynb 4
26+
class IDPPolicy():
27+
def __init__(self,
28+
environment_info: MDPInfo,
29+
obsprocessors: Optional[List[object]] = None,
30+
actionprocessors: Optional[List[object]] = None,
31+
agent_name: str | None = None,
32+
task: dict = None,
33+
price_function = None,
34+
g = None,
35+
):
36+
37+
alpha = np.array(task["alpha"])
38+
beta = np.array(task["beta"])
39+
assert type(alpha) == type(beta), "alpha and beta must be of the same type"
40+
if type(alpha) == None:
41+
alpha = np.zeros(environment_info.observation_space['features'].shape[0])
42+
beta = np.zeros(environment_info.observation_space['features'].shape[0])
43+
self.environment_info = environment_info
44+
self.task = task
45+
self.T = task["horizon"]
46+
self.alpha = alpha
47+
self.beta = beta
48+
if environment_info.observation_space['features'].shape[0] == 1:
49+
self.E_X = np.array([1])
50+
else:
51+
self.E_X = np.full(environment_info.observation_space['features'].shape[0], 1 / (2 * np.sqrt(environment_info.observation_space['features'].shape[0])))
52+
self.actionprocessors = actionprocessors
53+
self.price_function = price_function # Needs to return an np array
54+
self.g = g
55+
self.t = 0
56+
self.mode = "train"
57+
self.actionprocessors.append(ClipAction(environment_info.action_space.low, environment_info.action_space.high))
58+
59+
def draw_action(self, observation: np.ndarray):
60+
X = observation['features']
61+
B_t = observation['Inventory']
62+
price = self.price_function(X, self.alpha, self.beta)
63+
lagrangian = self.lagrangian(B_t)
64+
price = price + lagrangian
65+
for processor in self.actionprocessors:
66+
price = processor(price)
67+
68+
return np.array(price)
69+
70+
def lagrangian(self, B_t):
71+
"""
72+
Lagrangian function for the pricing problem
73+
"""
74+
avg_remaining_B = (2 * B_t) / (self.T - self.t +1)
75+
lagrangian = (avg_remaining_B - np.dot(self.alpha, self.E_X)) / np.dot(self.beta, self.E_X)
76+
return lagrangian
77+
def update_task(self, env):
78+
self.environment_info = env.mdp_info
79+
self.task = env.get_task()
80+
81+
self.alpha = np.array(self.task["alpha"])
82+
self.beta = np.array(self.task["beta"])
83+
if self.environment_info.observation_space['features'].shape[0] == 1:
84+
self.E_X = np.array([1])
85+
else:
86+
self.E_X = np.full(self.environment_info.observation_space['features'].shape[0], 1 / (2 * np.sqrt(self.environment_info.observation_space['features'].shape[0])))
87+
self.T = self.task["horizon"]
88+
89+
"""TODO add change in price function"""
90+
def fit(self, X, Y, action):
91+
self.t += 1
92+
93+
94+
def reset(self):
95+
pass
96+
97+
98+
# %% ../../../../nbs/30_agents/42_DP_agents/421_DP_inventory_agents/10_IDP_agent.ipynb 5
99+
class IDPCoreAgent(Agent):
100+
101+
"""
102+
Base class for clairvoyant agents.
103+
"""
104+
105+
def __init__(self,
106+
environment_info: MDPInfo,
107+
obsprocessors: Optional[List[object]] = [],
108+
actionprocessors: Optional[List[object]] = [],
109+
agent_name: str | None = None,
110+
task: dict = None,
111+
price_function = None,
112+
g = None,
113+
):
114+
115+
policy = IDPPolicy(environment_info=environment_info, obsprocessors=obsprocessors, actionprocessors=actionprocessors, task=task, price_function=price_function, g=g)
116+
self.agent_name = agent_name
117+
super().__init__(environment_info, policy)
118+
119+
def fit(self, dataset, **kwargs):
120+
X = dataset[0][0]["features"]
121+
Y = kwargs["demand"][0]
122+
action = dataset[0][1]
123+
self.policy.fit(X, Y, action)
124+
125+
def update_task(self, env):
126+
self.policy.update_task(env)
127+
128+
129+
130+
131+
# %% ../../../../nbs/30_agents/42_DP_agents/421_DP_inventory_agents/10_IDP_agent.ipynb 6
132+
class IDPAgent(PricingMushroomBaseAgent):
133+
"""
134+
Wrapper class for IDPCoreAgent to interact with MushroomRL.
135+
"""
136+
def __init__(self,
137+
environment_info: MDPInfo,
138+
obsprocessors: Optional[List[object]] =[],
139+
actionprocessors: Optional[List[object]] = [],
140+
agent_name: str | None = None,
141+
task: dict = None,
142+
price_function = None,
143+
g = None,
144+
):
145+
self.agent = IDPCoreAgent(environment_info = environment_info,
146+
obsprocessors = obsprocessors,
147+
actionprocessors = actionprocessors,
148+
agent_name = agent_name,
149+
task=task,
150+
price_function = price_function,
151+
g = g)
152+
super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)
153+
def update_task(self, env: object):
154+
""" Update the environment specific parameters of the agent """
155+
self.agent.update_task(env)
156+

ddopai/agents/dynamic_pricing/inventory_constrained/__init__.py

Whitespace-only changes.

nbs/30_agents/40_base_agents/10_AGENT_CLASSES.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@
6767
" \"TS\": \"ddopai.agents.dynamic_pricing.TS.TSAgent\",\n",
6868
" \"UCB\": \"ddopai.agents.dynamic_pricing.UCB.UCBAgent\",\n",
6969
" \"Clairvoyant\": \"ddopai.agents.dynamic_pricing.clairvoyant.ClairvoyantAgent\",\n",
70-
" \"MTS\": \"ddopai.agents.dynamic_pricing.MTS.MTSAgent\"\n",
70+
" \"MTS\": \"ddopai.agents.dynamic_pricing.MTS.MTSAgent\",\n",
71+
" \n",
72+
" \"IDP\": \"ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent\"\n",
7173
"}"
7274
]
7375
},

0 commit comments

Comments
 (0)