
Commit 6162b81

fix branch conflict
1 parent 2cee01a commit 6162b81

6 files changed: +40 -28 lines changed


ding/framework/message_queue/perfs/perf_nng.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 REPEAT = 10
 LENGTH = 5
 EXP_NUMS = 2
-UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024]
+UNIT_SIZE_LIST = [64, 512, 1 * 1024, 4 * 1024, 64 * 1024]


 @click.command(context_settings=dict(help_option_names=['-h', '--help']))

ding/framework/message_queue/perfs/perf_shm.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@

 LENGTH = 5
 REPEAT = 10
-UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024]
+UNIT_SIZE_LIST = [64, 512, 1 * 1024, 4 * 1024, 64 * 1024]
 logging.getLogger().setLevel(logging.INFO)


ding/framework/message_queue/perfs/perf_torchrpc_nccl.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 LENGTH = 5
 REPEAT = 2
 MAX_EXP_NUMS = 10
-UNIT_SIZE_LIST = [64, 1024, 64 * 1024, 512 * 1024, 2 * 1024 * 1024]
+UNIT_SIZE_LIST = [64, 512, 1 * 1024, 4 * 1024, 64 * 1024]


 @dataclass
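
All three perf scripts sweep the same payload sizes; this commit trims the sweep's upper end from 2 MiB down to 64 KiB. For orientation, a minimal hypothetical sketch of how such a size sweep is typically consumed; the send() stub and timing loop below are illustrative, not the benchmarks' real harness:

import time

# Hypothetical sketch: build a payload of each size in UNIT_SIZE_LIST and time
# REPEAT sends. send() stands in for whichever transport (nng, shared memory,
# torchrpc) the real benchmark exercises.
UNIT_SIZE_LIST = [64, 512, 1 * 1024, 4 * 1024, 64 * 1024]
REPEAT = 10


def send(payload: bytes) -> None:
    pass  # placeholder for the transport call under test


for unit_size in UNIT_SIZE_LIST:
    payload = bytes(unit_size)  # zero-filled buffer of unit_size bytes
    start = time.perf_counter()
    for _ in range(REPEAT):
        send(payload)
    elapsed = time.perf_counter() - start
    print("size={}B avg={:.1f}us".format(unit_size, elapsed / REPEAT * 1e6))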

ding/framework/middleware/collector.py

Lines changed: 5 additions & 1 deletion
@@ -38,8 +38,12 @@ def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size
         self.policy = policy
         self.random_collect_size = random_collect_size
         self._transitions = TransitionList(self.env.env_num)
+        if hasattr(cfg, "env") and hasattr(cfg.env, "manager"):
+            use_cuda_shared_memory = cfg.env.manager.cuda_shared_memory
+        else:
+            use_cuda_shared_memory = False
         self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
-        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))
+        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions, use_cuda_shared_memory))

     def __call__(self, ctx: "OnlineRLContext") -> None:
         """

ding/framework/middleware/distributer.py

Lines changed: 26 additions & 19 deletions
@@ -13,11 +13,7 @@

 class ContextExchanger:

-    def __init__(
-        self,
-        skip_n_iter: int = 1,
-        storage_loader: Optional[StorageLoader] = None,
-    ) -> None:
+    def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
         """
         Overview:
             Exchange context between processes,
@@ -41,9 +37,8 @@ def __init__(
         self._storage_loader = storage_loader

         # Both nng and torchrpc use background threads to trigger the receiver's recv action,
-        # there is a race condition between sender and sender, and between senders and receiver.
+        # there is a race condition between the listen thread and the polling thread.
         self._put_lock = LockContext(LockContextType.THREAD_LOCK)
-        self._recv_ready = False
         self._bypass_eventloop = task.router.mq_type == MQType.RPC

         for role in task.role:  # Only subscribe to other roles
@@ -101,7 +96,6 @@ def callback(payload: Dict):
                     getattr(self, fn_name)(item)
                 else:
                     logging.warning("Receive unexpected key ({}) in context exchanger".format(key))
-            self._recv_ready = True

         if isinstance(payload, Storage):
             assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
@@ -126,19 +120,27 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]:
         return payload

     def merge(self, ctx: "Context"):
-
+        # Dict's assignment is not an atomic operation, even if len(self._state)
+        # is not 0, the value corresponding to the key maybe empty.
+        ready = 0
         if task.has_role(task.role.LEARNER):
             # Learner should always wait for trajs.
             # TODO: Automaticlly wait based on properties, not roles.
-            while self._recv_ready is False:
-                sleep(0.01)
+            while ready == 0:
+                with self._put_lock:
+                    ready = len(self._state)
+                if ready == 0:
+                    sleep(0.01)
         elif ctx.total_step >= self._skip_n_iter:
             start = time()
-            while self._recv_ready is False:
-                if time() - start > 60:
-                    logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
-                    break
-                sleep(0.01)
+            while ready == 0:
+                with self._put_lock:
+                    ready = len(self._state)
+                if ready == 0:
+                    if time() - start > 60:
+                        logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
+                        break
+                    sleep(0.01)

         with self._put_lock:
             for k, v in self._state.items():
@@ -148,7 +150,6 @@ def merge(self, ctx: "Context"):
                 else:
                     setattr(ctx, k, v)
             self._state = {}
-            self._recv_ready = False

     # Handle each attibute of context
     def _put_trajectories(self, traj: List[Any]):
@@ -173,14 +174,14 @@ def _fetch_episodes(self, episodes: List[Any]):
         if task.has_role(task.role.COLLECTOR):
             return episodes

-    def _put_trajectory_end_idx(self, trajectory_end_idx: List[int]):
+    def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
         if not task.has_role(task.role.LEARNER):
             return
         if "trajectory_end_idx" not in self._state:
             self._state["trajectory_end_idx"] = []
         self._state["trajectory_end_idx"].extend(trajectory_end_idx)

-    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[int]):
+    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
         if task.has_role(task.role.COLLECTOR):
             return trajectory_end_idx

@@ -202,6 +203,12 @@ def _put_env_episode(self, increment_env_episode: int):
             self._state['increment_env_episode'] = 0
         self._state["increment_env_episode"] += increment_env_episode

+    def _fetch_env_episode(self, env_episode: int):
+        if task.has_role(task.role.COLLECTOR):
+            increment_env_episode = env_episode - self._local_state['env_episode']
+            self._local_state['env_episode'] = env_episode
+            return increment_env_episode
+
     def _put_train_iter(self, train_iter: int):
         if not task.has_role(task.role.LEARNER):
             self._state["train_iter"] = train_iter
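
With _recv_ready removed, merge() now decides readiness by checking len(self._state) while holding self._put_lock, so the flag can no longer be set before all keys are fully written. A condensed, standalone sketch of that waiting pattern; the module-level names here are placeholders, not the class's real attributes beyond what the diff shows:

import threading
import time

_put_lock = threading.Lock()  # stands in for LockContext(LockContextType.THREAD_LOCK)
_state = {}                   # the dict the receiver callback fills under the same lock


def wait_for_state(timeout: float = 60.0) -> bool:
    """Poll until the receiver thread has published at least one key into _state."""
    start = time.time()
    while True:
        with _put_lock:          # only read the shared dict while holding the lock
            if len(_state) > 0:
                return True
        if time.time() - start > timeout:
            return False         # mirrors the 60 s warning-and-break path in merge()
        time.sleep(0.01)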

ding/framework/middleware/functional/collector.py

Lines changed: 6 additions & 5 deletions
@@ -84,7 +84,12 @@ def _inference(ctx: "OnlineRLContext"):
     return _inference


-def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList) -> Callable:
+def rolloutor(
+        policy: Policy,
+        env: BaseEnvManager,
+        transitions: TransitionList,
+        use_cuda_shared_memory: bool = False
+) -> Callable:
     """
     Overview:
         The middleware that executes the transition process in the env.
@@ -99,10 +104,6 @@ def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList)

     env_episode_id = [_ for _ in range(env.env_num)]
     current_id = env.env_num
-    use_cuda_shared_memory = False
-
-    if hasattr(cfg, "env") and hasattr(cfg.env, "manager"):
-        use_cuda_shared_memory = cfg.env.manager.cuda_shared_memory

     def _rollout(ctx: "OnlineRLContext"):
         """
