class ContextExchanger:

-    def __init__(
-            self,
-            skip_n_iter: int = 1,
-            storage_loader: Optional[StorageLoader] = None,
-    ) -> None:
+    def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
        """
        Overview:
            Exchange context between processes,
@@ -41,9 +37,8 @@ def __init__(
        self._storage_loader = storage_loader

        # Both nng and torchrpc use background threads to trigger the receiver's recv action,
-        # there is a race condition between sender and sender, and between senders and receiver.
+        # there is a race condition between the listen thread and the polling thread.
        self._put_lock = LockContext(LockContextType.THREAD_LOCK)
-        self._recv_ready = False
        self._bypass_eventloop = task.router.mq_type == MQType.RPC

        for role in task.role:  # Only subscribe to other roles
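# A minimal standalone sketch (not part of the patch) of the race this hunk guards
# against: a background listen thread fills a shared dict while a polling thread
# waits on it. A plain threading.Lock stands in for LockContext here, which is an
# assumption for illustration, not the real class.
import threading
import time

state = {}
lock = threading.Lock()

def listen_thread():
    # Simulates the message-queue callback that stores a received payload.
    time.sleep(0.05)
    with lock:
        state["trajectories"] = [1, 2, 3]

def polling_thread():
    # Simulates merge(): only inspect the shared dict under the lock, never mid-write.
    ready = 0
    while ready == 0:
        with lock:
            ready = len(state)
        if ready == 0:
            time.sleep(0.01)
    with lock:
        print("received:", state)

t1 = threading.Thread(target=listen_thread)
t2 = threading.Thread(target=polling_thread)
t1.start()
t2.start()
t1.join()
t2.join()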
@@ -101,7 +96,6 @@ def callback(payload: Dict):
                    getattr(self, fn_name)(item)
                else:
                    logging.warning("Receive unexpected key ({}) in context exchanger".format(key))
-            self._recv_ready = True

        if isinstance(payload, Storage):
            assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
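# A rough sketch (illustrative only) of the name-based dispatch used in the callback
# above: each key of a received payload is routed to a matching _put_<key> handler if
# one exists. The Sink class and its fields are made-up names for the example.
import logging

class Sink:

    def __init__(self):
        self.state = {}

    def _put_train_iter(self, train_iter: int):
        self.state["train_iter"] = train_iter

    def callback(self, payload: dict):
        for key, item in payload.items():
            fn_name = "_put_{}".format(key)
            if hasattr(self, fn_name):
                getattr(self, fn_name)(item)
            else:
                logging.warning("Receive unexpected key ({}) in context exchanger".format(key))

sink = Sink()
sink.callback({"train_iter": 10, "unknown_key": 1})
assert sink.state == {"train_iter": 10}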
@@ -126,19 +120,27 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]:
        return payload

    def merge(self, ctx: "Context"):
-
+        # Dict assignment is not an atomic operation, so even if len(self._state)
+        # is not 0, the value corresponding to a key may still be empty.
+        ready = 0
        if task.has_role(task.role.LEARNER):
            # Learner should always wait for trajs.
            # TODO: Automatically wait based on properties, not roles.
-            while self._recv_ready is False:
-                sleep(0.01)
+            while ready == 0:
+                with self._put_lock:
+                    ready = len(self._state)
+                if ready == 0:
+                    sleep(0.01)
        elif ctx.total_step >= self._skip_n_iter:
            start = time()
-            while self._recv_ready is False:
-                if time() - start > 60:
-                    logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
-                    break
-                sleep(0.01)
+            while ready == 0:
+                with self._put_lock:
+                    ready = len(self._state)
+                if ready == 0:
+                    if time() - start > 60:
+                        logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
+                        break
+                    sleep(0.01)

        with self._put_lock:
            for k, v in self._state.items():
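# A hedged sketch (not part of the patch) of the timeout-bounded wait used in the
# non-learner branch above: poll a readiness check under the lock, back off briefly
# between checks, and give up with a warning after a deadline. wait_for_state and its
# parameters are illustrative names, not part of ContextExchanger's API.
import logging
import threading
from time import sleep, time
from typing import Dict

def wait_for_state(state: Dict, lock: threading.Lock, timeout: float = 60.0, interval: float = 0.01) -> bool:
    start = time()
    while True:
        with lock:
            if len(state) > 0:
                return True  # Something was received; the caller can merge it under the lock.
        if time() - start > timeout:
            logging.warning("Timeout when waiting for new context!")
            return False  # Give up and continue without new context, mirroring the break above.
        sleep(interval)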
@@ -148,7 +150,6 @@ def merge(self, ctx: "Context"):
                else:
                    setattr(ctx, k, v)
            self._state = {}
-        self._recv_ready = False

    # Handle each attribute of context
    def _put_trajectories(self, traj: List[Any]):
@@ -173,14 +174,14 @@ def _fetch_episodes(self, episodes: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return episodes

-    def _put_trajectory_end_idx(self, trajectory_end_idx: List[int]):
+    def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectory_end_idx" not in self._state:
            self._state["trajectory_end_idx"] = []
        self._state["trajectory_end_idx"].extend(trajectory_end_idx)

-    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[int]):
+    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if task.has_role(task.role.COLLECTOR):
            return trajectory_end_idx

@@ -202,6 +203,12 @@ def _put_env_episode(self, increment_env_episode: int):
            self._state['increment_env_episode'] = 0
        self._state["increment_env_episode"] += increment_env_episode

+    def _fetch_env_episode(self, env_episode: int):
+        if task.has_role(task.role.COLLECTOR):
+            increment_env_episode = env_episode - self._local_state['env_episode']
+            self._local_state['env_episode'] = env_episode
+            return increment_env_episode
+
    def _put_train_iter(self, train_iter: int):
        if not task.has_role(task.role.LEARNER):
            self._state["train_iter"] = train_iter
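# A rough sketch (illustrative, outside the patch) of the increment exchange that
# _fetch_env_episode / _put_env_episode implement: the collector sends only the delta
# since its last fetch, tracked in a local dict, and the learner accumulates those
# deltas. The module-level dicts below are stand-ins for self._local_state and
# self._state.
local_state = {"env_episode": 0}   # collector side: last absolute value already sent
remote_state = {}                  # learner side: accumulated increments

def fetch_env_episode(env_episode: int) -> int:
    # Collector: compute the delta since the previous fetch and remember the new value.
    increment = env_episode - local_state["env_episode"]
    local_state["env_episode"] = env_episode
    return increment

def put_env_episode(increment_env_episode: int) -> None:
    # Learner: accumulate the received delta.
    if "increment_env_episode" not in remote_state:
        remote_state["increment_env_episode"] = 0
    remote_state["increment_env_episode"] += increment_env_episode

put_env_episode(fetch_env_episode(3))   # collector has finished 3 episodes so far
put_env_episode(fetch_env_episode(5))   # two more episodes since the last exchange
assert remote_state["increment_env_episode"] == 5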