66
77import  logging 
88import  random 
9+ from  collections  import  deque 
910from  dataclasses  import  dataclass 
11+ from  operator  import  itemgetter 
1012from  typing  import  Any , Callable 
1113
1214from  monarch .actor  import  endpoint 
1921logger .setLevel (logging .INFO )
2022
2123
@dataclass
class BufferEntry:
    """A single replay-buffer slot: one episode plus sampling bookkeeping."""

    # The stored episode; "Episode" is a forward reference to a type declared
    # elsewhere in the project.
    data: "Episode"
    # How many times sample() has returned this entry; read by eviction
    # policies to retire over-sampled entries.
    sample_count: int = 0
28+ 
29+ 
30+ def  age_evict (
31+     buffer : deque , policy_version : int , max_samples : int  =  None , max_age : int  =  None 
32+ ) ->  list [int ]:
33+     """Buffer eviction policy, remove old or over-sampled entries""" 
34+     indices  =  []
35+     for  i , entry  in  enumerate (buffer ):
36+         if  max_age  and  policy_version  -  entry .data .policy_version  >  max_age :
37+             continue 
38+         if  max_samples  and  entry .sample_count  >=  max_samples :
39+             continue 
40+         indices .append (i )
41+     return  indices 
42+ 
43+ 
44+ def  random_sample (buffer : deque , sample_size : int , policy_version : int ) ->  list [int ]:
45+     """Buffer random sampling policy""" 
46+     if  sample_size  >  len (buffer ):
47+         return  None 
48+     return  random .sample (range (len (buffer )), k = sample_size )
49+ 
50+ 
@dataclass
class ReplayBuffer(ForgeActor):
    """Simple in-memory replay buffer implementation."""

    # Episodes per data-parallel rank returned by each sample() call.
    batch_size: int
    # Number of data-parallel ranks; sample() returns dp_size groups of batch_size.
    dp_size: int = 1
    # Entries older than this many policy versions are evicted. None disables.
    max_policy_age: int | None = None
    # Hard capacity of the backing deque (its maxlen). None means unbounded.
    max_buffer_size: int | None = None
    # Times an entry may be re-sampled before eviction (0 = single use).
    max_resample_count: int | None = 0
    # RNG seed; a random one is drawn in setup() when None.
    seed: int | None = None
    # Post-processing hook applied to the (dp_size, batch_size) batch structure.
    collate: Callable = lambda batch: batch
    # Pluggable policies; see age_evict / random_sample for the expected signatures.
    eviction_policy: Callable = age_evict
    sample_policy: Callable = random_sample
3464
3565    @endpoint  
3666    async  def  setup (self ) ->  None :
37-         self .buffer : list  =  [] 
67+         self .buffer : deque  =  deque ( maxlen = self . max_buffer_size ) 
3868        if  self .seed  is  None :
3969            self .seed  =  random .randint (0 , 2 ** 32 )
4070        random .seed (self .seed )
41-         self .sampler  =  random .sample 
4271
4372    @endpoint  
4473    async  def  add (self , episode : "Episode" ) ->  None :
45-         self .buffer .append (episode )
74+         self .buffer .append (BufferEntry ( episode ) )
4675        record_metric ("buffer/add/count_episodes_added" , 1 , Reduce .SUM )
4776
4877    @endpoint  
4978    @trace ("buffer_perf/sample" , track_memory = False ) 
5079    async  def  sample (
51-         self , curr_policy_version : int ,  batch_size :  int   |   None   =   None 
80+         self , curr_policy_version : int 
5281    ) ->  tuple [tuple [Any , ...], ...] |  None :
5382        """Sample from the replay buffer. 
5483
5584        Args: 
5685            curr_policy_version (int): The current policy version. 
57-             batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size 
58-                 passed in at initialization. 
5986
6087        Returns: 
6188            A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer. 
6289        """ 
6390        # Record sample request metric 
6491        record_metric ("buffer/sample/count_sample_requests" , 1 , Reduce .SUM )
6592
66-         bsz  =  batch_size  if  batch_size  is  not None  else  self .batch_size 
67-         total_samples  =  self .dp_size  *  bsz 
93+         total_samples  =  self .dp_size  *  self .batch_size 
6894
69-         # Evict old  episodes 
95+         # Evict episodes 
7096        self ._evict (curr_policy_version )
7197
72-         if  total_samples  >  len (self .buffer ):
73-             return  None 
74- 
75-         # Calculate buffer utilization 
76-         utilization_pct  =  (
77-             (total_samples  /  len (self .buffer )) *  100  if  len (self .buffer ) >  0  else  0 
78-         )
79- 
80-         record_metric (
81-             "buffer/sample/avg_buffer_utilization" ,
82-             len (self .buffer ),
83-             Reduce .MEAN ,
84-         )
85- 
86-         record_metric (
87-             "buffer/sample/avg_buffer_utilization_pct" ,
88-             utilization_pct ,
89-             Reduce .MEAN ,
90-         )
98+         # Calculate metrics 
99+         if  len (self .buffer ) >  0 :
100+             record_metric (
101+                 "buffer/sample/avg_data_utilization" ,
102+                 total_samples  /  len (self .buffer ),
103+                 Reduce .MEAN ,
104+             )
105+         if  self .max_buffer_size :
106+             record_metric (
107+                 "buffer/sample/avg_buffer_utilization" ,
108+                 len (self .buffer ) /  self .max_buffer_size ,
109+                 Reduce .MEAN ,
110+             )
91111
92112        # TODO: prefetch samples in advance 
93-         idx_to_sample  =  self .sampler (range (len (self .buffer )), k = total_samples )
94-         # Pop episodes in descending order to avoid shifting issues 
95-         popped  =  [self .buffer .pop (i ) for  i  in  sorted (idx_to_sample , reverse = True )]
96- 
97-         # Reorder popped episodes to match the original random sample order 
98-         sorted_idxs  =  sorted (idx_to_sample , reverse = True )
99-         idx_to_popped  =  dict (zip (sorted_idxs , popped ))
100-         sampled_episodes  =  [idx_to_popped [i ] for  i  in  idx_to_sample ]
113+         sampled_indices  =  self .sample_policy (
114+             self .buffer , total_samples , curr_policy_version 
115+         )
116+         if  sampled_indices  is  None :
117+             return  None 
118+         sampled_episodes  =  []
119+         for  entry  in  self ._collect (sampled_indices ):
120+             entry .sample_count  +=  1 
121+             sampled_episodes .append (entry .data )
101122
102123        # Reshape into (dp_size, bsz, ...) 
103124        reshaped_episodes  =  [
104-             sampled_episodes [dp_idx  *  bsz  : (dp_idx  +  1 ) *  bsz ]
125+             sampled_episodes [dp_idx  *  self . batch_size  : (dp_idx  +  1 ) *  self . batch_size ]
105126            for  dp_idx  in  range (self .dp_size )
106127        ]
107128
@@ -118,46 +139,69 @@ async def evict(self, curr_policy_version: int) -> None:
118139        """ 
119140        self ._evict (curr_policy_version )
120141
121-     def  _evict (self , curr_policy_version :  int )  ->   None :
142+     def  _evict (self , curr_policy_version ) :
122143        buffer_len_before_evict  =  len (self .buffer )
123-         self .buffer  =  [
124-             trajectory 
125-             for  trajectory  in  self .buffer 
126-             if  (curr_policy_version  -  trajectory .policy_version ) <=  self .max_policy_age 
127-         ]
128-         buffer_len_after_evict  =  len (self .buffer )
144+         indices  =  self .eviction_policy (
145+             self .buffer ,
146+             curr_policy_version ,
147+             self .max_resample_count  +  1 ,
148+             self .max_policy_age ,
149+         )
150+         self .buffer  =  deque (self ._collect (indices ))
129151
130152        # Record evict metrics 
131-         policy_staleness  =  [
132-             curr_policy_version  -  ep .policy_version  for  ep  in  self .buffer 
153+         policy_age  =  [
154+             curr_policy_version  -  ep .data . policy_version  for  ep  in  self .buffer 
133155        ]
134-         if  policy_staleness :
156+         if  policy_age :
135157            record_metric (
136-                 "buffer/evict/avg_policy_staleness " ,
137-                 sum (policy_staleness ) /  len (policy_staleness ),
158+                 "buffer/evict/avg_policy_age " ,
159+                 sum (policy_age ) /  len (policy_age ),
138160                Reduce .MEAN ,
139161            )
140162            record_metric (
141-                 "buffer/evict/max_policy_staleness " ,
142-                 max (policy_staleness ),
163+                 "buffer/evict/max_policy_age " ,
164+                 max (policy_age ),
143165                Reduce .MAX ,
144166            )
145167
146-         # Record eviction metrics 
147-         evicted_count  =  buffer_len_before_evict  -  buffer_len_after_evict 
148-         if  evicted_count  >  0 :
149-             record_metric (
150-                 "buffer/evict/sum_episodes_evicted" , evicted_count , Reduce .SUM 
151-             )
168+         evicted_count  =  buffer_len_before_evict  -  len (self .buffer )
169+         record_metric ("buffer/evict/sum_episodes_evicted" , evicted_count , Reduce .SUM )
152170
153171        logger .debug (
154172            f"maximum policy age: { self .max_policy_age } { curr_policy_version }  
155-             f"{ evicted_count } { buffer_len_after_evict }  
173+             f"{ evicted_count } { len ( self . buffer ) }  
156174        )
157175
    def _collect(self, indices: list[int]) -> list:
        """Efficiently traverse deque and collect elements at each requested index

        Visiting the requested indices in sorted order means each step rotates
        the deque only by the gap to the next index, avoiding repeated
        middle-of-deque random access (which is O(n) per index for a deque).
        The deque is restored to its original orientation before returning.
        Returned elements are the BufferEntry objects themselves, in the same
        order as `indices`.
        """
        n = len(self.buffer)
        # Nothing to do on an empty buffer or an empty request.
        if n == 0 or len(indices) == 0:
            return []

        # Normalize indices and store with their original order
        # (idx % n also wraps out-of-range indices into the valid window).
        indexed = [(pos, idx % n) for pos, idx in enumerate(indices)]
        indexed.sort(key=itemgetter(1))

        result = [None] * len(indices)
        rotations = 0  # logical current index
        total_rotation = 0  # total net rotation applied

        for orig_pos, idx in indexed:
            # Rotate just far enough to bring element `idx` to the front.
            move = idx - rotations
            self.buffer.rotate(-move)
            total_rotation += move
            rotations = idx
            result[orig_pos] = self.buffer[0]

        # Restore original deque orientation
        self.buffer.rotate(total_rotation)

        return result
201+ 
158202    @endpoint  
159203    async  def  _getitem (self , idx : int ):
160-         return  self .buffer [idx ]
204+         return  self .buffer [idx ]. data 
161205
162206    @endpoint  
163207    async  def  _numel (self ) ->  int :
0 commit comments