
Commit 2977097

Jinyu Wang authored

Add DDQN (#598)

* Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test (#590)
* Runnable. Should setup a benchmark and test performance.
* Refine logic
* Test DQN on GYM passed
* Refine explore strategy
* Minor
* Minor
* Add Dueling DQN in CIM scenario
* Resolve PR comments
* Add one more explanation
* fix env_sampler eval info list issue
* update version to 0.3.2a4

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
1 parent b3c6a58 commit 2977097

File tree

25 files changed: +541 -334 lines changed

examples/cim/rl/algorithms/dqn.py

Lines changed: 52 additions & 20 deletions

@@ -1,10 +1,11 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from typing import Optional, Tuple
 
 import torch
 from torch.optim import RMSprop
 
-from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -23,32 +24,62 @@
 
 
 class MyQNet(DiscreteQNet):
-    def __init__(self, state_dim: int, action_num: int) -> None:
+    def __init__(
+        self,
+        state_dim: int,
+        action_num: int,
+        dueling_param: Optional[Tuple[dict, dict]] = None,
+    ) -> None:
         super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
-        self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
-        self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
+
+        self._use_dueling = dueling_param is not None
+        self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
+        if self._use_dueling:
+            q_kwargs, v_kwargs = dueling_param
+            self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
+            self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)
+
+        self._optim = RMSprop(self.parameters(), lr=learning_rate)
 
     def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
-        return self._fc(states)
+        logits = self._fc(states)
+        if self._use_dueling:
+            q = self._q(logits)
+            v = self._v(logits)
+            logits = q - q.mean(dim=1, keepdim=True) + v
+        return logits
 
 
 def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
+    q_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": torch.nn.LeakyReLU,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+    v_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": None,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+
     return ValueBasedPolicy(
         name=name,
-        q_net=MyQNet(state_dim, action_num),
-        exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(2, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 5,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        q_net=MyQNet(
+            state_dim,
+            action_num,
+            dueling_param=(q_kwargs, v_kwargs),
+        ),
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
        warmup=100,
    )
 
@@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
             num_epochs=10,
             soft_update_coef=0.1,
             double=False,
-            random_overwrite=False,
+            alpha=1.0,
+            beta=1.0,
         ),
     )
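
Note on the dueling head added above: the forward pass combines an advantage stream and a value stream as Q = V + A - mean(A), the standard dueling-DQN aggregation. Subtracting the per-state mean of the advantages keeps the V/A decomposition identifiable, since shifting A by a constant and V by its negation would otherwise yield identical Q-values. A minimal, self-contained sketch of just that aggregation step (tensor names are illustrative, not part of the MARO API):

    import torch

    def dueling_q(advantages: torch.Tensor, state_values: torch.Tensor) -> torch.Tensor:
        # advantages: (batch, action_num), state_values: (batch, 1).
        # Center the advantage stream so that the mean of Q(s, .) equals V(s).
        return state_values + advantages - advantages.mean(dim=1, keepdim=True)

    adv = torch.randn(4, 6)  # batch of 4 states, 6 discrete actions
    val = torch.randn(4, 1)
    assert dueling_q(adv, val).shape == (4, 6)

The new alpha and beta fields in DQNParams line up with the "prioritized sampling support" in the commit message; in standard prioritized experience replay, alpha shapes the priority distribution and beta scales the importance-sampling correction, though their exact semantics here are defined by DQNParams rather than shown in this diff. double=False leaves the newly added Double DQN target disabled in this example.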

examples/cim/rl/config.py

Lines changed: 1 addition & 1 deletion

@@ -35,4 +35,4 @@
 
 action_num = len(action_shaping_conf["action_space"])
 
-algorithm = "ppo"  # ac, ppo, dqn or discrete_maddpg
+algorithm = "dqn"  # ac, ppo, dqn or discrete_maddpg

examples/vm_scheduling/rl/algorithms/dqn.py

Lines changed: 2 additions & 14 deletions

@@ -6,7 +6,7 @@
 from torch.optim import SGD
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
 
-from maro.rl.exploration import MultiLinearExplorationScheduler
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
     return ValueBasedPolicy(
         name=name,
         q_net=MyQNet(state_dim, action_num, num_features),
-        exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(100, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 400,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
         warmup=100,
     )
 
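Both examples now hand ValueBasedPolicy a single EpsilonGreedy object in place of the old (strategy function, kwargs) tuple plus scheduler options. As an illustration of the behavior only, not MARO's actual implementation, a fixed-epsilon strategy reduces to:

    import numpy as np

    def epsilon_greedy_action(greedy_action: int, num_actions: int, epsilon: float,
                              rng: np.random.Generator) -> int:
        # With probability epsilon explore uniformly at random; otherwise exploit.
        if rng.random() < epsilon:
            return int(rng.integers(num_actions))
        return greedy_action

    rng = np.random.default_rng(0)
    sampled = [epsilon_greedy_action(3, num_actions=5, epsilon=0.4, rng=rng) for _ in range(10)]

One behavioral difference worth noting: the vm_scheduling example previously used a MaskedEpsGreedy built from the state's legal-action information, while the plain EpsilonGreedy above it does not mask illegal actions.
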
maro/__misc__.py

Lines changed: 1 addition & 1 deletion

@@ -2,6 +2,6 @@
 # Licensed under the MIT license.
 
 
-__version__ = "0.3.2a3"
+__version__ = "0.3.2a4"
 
 __data_version__ = "0.2"

maro/rl/exploration/__init__.py

Lines changed: 4 additions & 8 deletions

@@ -1,14 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
-from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
+from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration
 
 __all__ = [
-    "AbsExplorationScheduler",
-    "LinearExplorationScheduler",
-    "MultiLinearExplorationScheduler",
-    "epsilon_greedy",
-    "gaussian_noise",
-    "uniform_noise",
+    "ExploreStrategy",
+    "EpsilonGreedy",
+    "LinearExploration",
 ]
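
The package now exports class-based strategies instead of standalone functions and schedulers (the scheduling module itself is deleted below). Judging from the call sites in this commit, the shared interface is roughly the sketch below; the get_action signature is an assumption inferred from how ValueBasedPolicy consumes explore_strategy, not something this diff shows:

    from abc import ABCMeta, abstractmethod

    import numpy as np

    class ExploreStrategy(metaclass=ABCMeta):
        @abstractmethod
        def get_action(self, state: np.ndarray, action: np.ndarray, **kwargs) -> np.ndarray:
            # Map the policy's greedy actions to (possibly) exploratory actions.
            raise NotImplementedError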

maro/rl/exploration/scheduling.py

Lines changed: 0 additions & 127 deletions
This file was deleted.
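
With the schedulers gone, epsilon decay has to live inside the strategy object itself. A hypothetical decaying strategy in the spirit of LinearExploration (the class name and constructor arguments below are illustrative assumptions, not the library's documented signature) could look like:

    import numpy as np

    class LinearDecayEpsilonGreedy:
        """Hypothetical sketch: epsilon is interpolated linearly on each call."""

        def __init__(self, num_actions: int, start_epsilon: float,
                     end_epsilon: float, decay_steps: int) -> None:
            self._num_actions = num_actions
            self._start = start_epsilon
            self._end = end_epsilon
            self._decay_steps = decay_steps
            self._step = 0
            self._rng = np.random.default_rng()

        def get_action(self, greedy_action: int) -> int:
            # Linearly anneal epsilon from start to end over decay_steps calls.
            frac = min(self._step / self._decay_steps, 1.0)
            epsilon = self._start + frac * (self._end - self._start)
            self._step += 1
            if self._rng.random() < epsilon:
                return int(self._rng.integers(self._num_actions))
            return greedy_action

This is why the examples above no longer register exploration_scheduling_options: the decay that MultiLinearExplorationScheduler used to drive externally is now a concern of the strategy class.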
