1717"""generator.py."""
1818
1919import asyncio
20+ import os
2021from abc import ABC , abstractmethod
2122
2223import numpy as np
2324from torch import Tensor
2425
25- from infscale .configs .controller import GenParams , ReqGenEnum
26+ from infscale import get_logger
27+ from infscale .configs .controller import GenParams , RateScheduleItem , ReqGenEnum
2628from infscale .execution .metrics_collector import MetricsCollector
2729from infscale .module .dataset import HuggingFaceDataset
2830
2931
32+ logger = None
33+
34+
3035class Generator (ABC ):
3136 """Abstract Generator class."""
3237
@@ -44,6 +49,9 @@ def initialize(
4449 self ._mc = mc
4550 self ._seqno = 0
4651
52+ global logger
53+ logger = get_logger (f"{ os .getpid ()} " )
54+
4755 @abstractmethod
4856 async def get (self ) -> list [Tensor | None ]:
4957 """Return generated requests as batch."""
@@ -100,9 +108,12 @@ async def _generate(self) -> None:
100108 self ._mc .update (self ._seqno )
101109 self ._seqno += 1
102110
103- iat = np . random . exponential ( scale = 1 / self ._batch_rate )
111+ iat = self ._compute_iat ( )
104112 await asyncio .sleep (iat )
105113
114+ def _compute_iat (self ):
115+ return np .random .exponential (scale = 1 / self ._batch_rate )
116+
106117 async def get (self ) -> list [Tensor | None ]:
107118 """Return one batch of requests.
108119
@@ -122,6 +133,74 @@ async def get(self) -> list[Tensor | None]:
122133 return batches
123134
124135
class MultiRateExponentialGenerator(ExponentialGenerator):
    """Exponential generator with replay-dependent rate schedule.

    The sending rate starts from ``params.rate`` and is switched to the rate
    of each ``params.schedule`` item once the current replay reaches that
    item's ``replay_index``. Inter-arrival times are always drawn from an
    exponential distribution whose mean is ``batch_size / rate``.
    """

    def initialize(
        self,
        dataset,
        params,
        batch_size,
        mc,
    ) -> None:
        """Set up the rate schedule and start the background generation task.

        Args:
            dataset: dataset handed to ``Generator.initialize``.
            params: generator parameters; ``rate``, ``schedule`` and
                ``replay`` are read from it. Must not be ``None``.
            batch_size: number of requests per batch; the per-batch rate is
                ``rate / batch_size``.
            mc: metrics collector handed to ``Generator.initialize``.
        """
        assert params is not None
        # intentionally bypassing super().initialize
        # for properly setting up queue and event and to avoid duplicating
        # asyncio task creation for _generate method
        Generator.initialize(self, dataset, params, batch_size, mc)

        self.range_list = self._prepare_schedule(
            self._params.rate, self._params.schedule, self._params.replay
        )

        # _prepare_schedule always returns at least one range, so index 0 is
        # valid and holds the rate that applies from replay 0.
        self._range_index = 0
        rate = self.range_list[0][2]
        self._batch_rate = rate / self._batch_size

        self._queue = asyncio.Queue()
        self._gen_evt = asyncio.Event()
        # Keep a strong reference to the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage-collected before
        # it finishes.
        self._gen_task = asyncio.create_task(self._generate())

        msg = f"generator initialized with rate={rate}"
        msg += f" replay rate update schedule={self._params.schedule}"
        logger.info(msg)

    def _prepare_schedule(
        self, base_rate: float, schedule: list[RateScheduleItem], max_replay: int
    ) -> list[tuple[int, int, float]]:
        """Convert replay-based schedule into continuous replay ranges.

        Args:
            base_rate: rate used from replay 0 until the first schedule item.
            schedule: items carrying ``replay_index`` and ``rate``; they need
                not be sorted. ``None`` or an empty list yields one range.
            max_replay: upper bound (inclusive) of the last range.

        Returns:
            A non-empty list of ``(first_replay, last_replay, rate)`` tuples,
            ascending and contiguous, with both bounds inclusive.
        """
        ordered = sorted(schedule or [], key=lambda s: s.replay_index)

        ranges = []
        start = 0
        rate = base_rate

        for item in ordered:
            # Only close the running range when it is non-empty. An item whose
            # replay_index does not move past `start` (index 0, or a duplicate
            # index) simply overrides the rate of the range being built, so no
            # inverted range such as (0, -1, rate) is ever emitted.
            if item.replay_index > start:
                ranges.append((start, item.replay_index - 1, rate))
                start = item.replay_index
            rate = item.rate

        # last range goes until max_replay
        ranges.append((start, max_replay, rate))
        return ranges

    def _compute_iat(self):
        """Return the next inter-arrival time, updating the rate if needed.

        The current replay is computed as
        ``self._params.replay - self._dataset._replay``. When it has moved
        past the active range, the active range (and ``self._batch_rate``) is
        advanced before the exponential sample is drawn.
        """
        # NOTE(review): assumes the current replay only grows over time (i.e.
        # ``_dataset._replay`` counts down) - confirm against the dataset.
        current_replay = self._params.replay - self._dataset._replay

        # Ranges are ascending and contiguous, so walk forward until the range
        # covers the current replay. The walk may skip several ranges at once
        # and never leaves the list, so an out-of-range replay value keeps the
        # last rate instead of raising IndexError.
        last_index = len(self.range_list) - 1
        index = self._range_index
        while index < last_index and current_replay > self.range_list[index][1]:
            index += 1

        if index != self._range_index:
            self._range_index = index
            rate = self.range_list[index][2]
            self._batch_rate = rate / self._batch_size

            logger.info(f"sending rate updated to {rate}")

        return np.random.exponential(scale=1 / self._batch_rate)
203+
125204class GeneratorFactory :
126205 """Request generator factory class."""
127206
@@ -131,6 +210,7 @@ def get(sort: ReqGenEnum) -> Generator:
131210 generators = {
132211 ReqGenEnum .DEFAULT : DefaultGenerator (),
133212 ReqGenEnum .EXP : ExponentialGenerator (),
213+ ReqGenEnum .MULTIRATE_EXP : MultiRateExponentialGenerator (),
134214 }
135215
136216 return generators [sort ]
0 commit comments