Skip to content

Commit c735a9a

Browse files
authored
feat: multi rate generator (#297)
In order to test scale_in properly, we need a way to dynamically increase or decrease the sending rate. To that end, we added a multi-rate generator schedule. In each schedule entry, the key is the replay at which the rate update should happen and the value is the new rate. We also added a controller configuration as an example of how this can be used.
1 parent 3362fd1 commit c735a9a

File tree

4 files changed

+155
-7
lines changed

4 files changed

+155
-7
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
---
2+
# api_port: 8080
3+
# ctrl_port: 31310
4+
autoscale: true
5+
deploy_policy: static # options: even, packing, random (default), static
6+
# if autoscale is enabled, deploy_policy is static
7+
8+
# a directory path to job deployment templates/plans
9+
# pick a directory where plan files are located
10+
job_plans: ~/projects/infscale/examples/configs/plans
11+
12+
reqgen:
13+
sort: multirate_exponential
14+
params:
15+
in_memory: true
16+
replay: 20
17+
rate: 1600.0
18+
schedule:
19+
- replay_index: 3
20+
rate: 1800.0
21+
- replay_index: 10
22+
rate: 300.0

infscale/common/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def __init__(self, err_msg: str):
4343
super().__init__(err_msg)
4444

4545

46+
class InvalidGenConfig(InfScaleException):
    """Error signalling that a request generator configuration is invalid."""

    def __init__(self, err_msg: str):
        """Build the exception with ``err_msg`` as its message."""
        super().__init__(err_msg)
52+
53+
4654
class InsufficientResources(InfScaleException):
4755
"""Exception for insufficient agent resources."""
4856

infscale/configs/controller.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@
2020
from dataclasses import dataclass, field
2121
from enum import Enum
2222

23-
from infscale.common.constants import (
24-
APISERVER_PORT,
25-
CONTROLLER_PORT,
26-
)
23+
from infscale.common.constants import APISERVER_PORT, CONTROLLER_PORT
24+
from infscale.common.exceptions import InvalidGenConfig
2725

2826

2927
class DeploymentPolicyEnum(Enum):
@@ -47,6 +45,7 @@ class ReqGenEnum(str, Enum):
4745

4846
DEFAULT = "default"
4947
EXP = "exponential"
48+
MULTIRATE_EXP = "multirate_exponential"
5049

5150

5251
@dataclass
@@ -67,7 +66,20 @@ class ExponentialParams(DefaultParams):
6766
rate: float = 1.0 # rate is per-second
6867

6968

70-
GenParams = DefaultParams | ExponentialParams
69+
@dataclass
class RateScheduleItem:
    """Single entry of a multi-rate schedule: switch to ``rate`` at a replay."""

    # replay at which the new rate takes effect
    replay_index: int
    # sending rate to use from that replay on
    rate: float
73+
74+
75+
@dataclass
class MultiRateExponentialParams(ExponentialParams):
    """Parameters of the exponential generator that follows a rate schedule."""

    # rate changes to apply while replaying; defaults to an empty schedule
    schedule: list[RateScheduleItem] = field(default_factory=list)
80+
81+
82+
GenParams = DefaultParams | ExponentialParams | MultiRateExponentialParams
7183

7284

7385
@dataclass
@@ -94,6 +106,32 @@ def __post_init__(self):
94106
case ReqGenEnum.EXP:
95107
self.params = ExponentialParams(**self.params)
96108

109+
case ReqGenEnum.MULTIRATE_EXP:
110+
self.params = MultiRateExponentialParams(**self.params)
111+
112+
if self.params.replay is None:
113+
raise InvalidGenConfig(f"Replay param is required.")
114+
115+
if len(self.params.schedule) > 0:
116+
self.params.schedule = [
117+
RateScheduleItem(**item) if isinstance(item, dict) else item
118+
for item in self.params.schedule
119+
]
120+
replay_indexes = [
121+
item.replay_index for item in self.params.schedule
122+
]
123+
min_key = min(replay_indexes)
124+
max_key = max(replay_indexes)
125+
126+
if min_key <= 0:
127+
raise InvalidGenConfig(
128+
f"invalid schedule: iteration {min_key} must be positive"
129+
)
130+
if max_key > self.params.replay:
131+
msg = "invalid schedule:"
132+
msg += f" iteration {max_key} exceeds replay limit {self.params.replay}"
133+
raise InvalidGenConfig(msg)
134+
97135

98136
@dataclass
99137
class CtrlConfig:

infscale/request/generator.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717
"""generator.py."""
1818

1919
import asyncio
20+
import os
2021
from abc import ABC, abstractmethod
2122

2223
import numpy as np
2324
from torch import Tensor
2425

25-
from infscale.configs.controller import GenParams, ReqGenEnum
26+
from infscale import get_logger
27+
from infscale.configs.controller import GenParams, RateScheduleItem, ReqGenEnum
2628
from infscale.execution.metrics_collector import MetricsCollector
2729
from infscale.module.dataset import HuggingFaceDataset
2830

2931

32+
logger = None
33+
34+
3035
class Generator(ABC):
3136
"""Abstact Generator class."""
3237

@@ -44,6 +49,9 @@ def initialize(
4449
self._mc = mc
4550
self._seqno = 0
4651

52+
global logger
53+
logger = get_logger(f"{os.getpid()}")
54+
4755
@abstractmethod
4856
async def get(self) -> list[Tensor | None]:
4957
"""Return generated requests as batch."""
@@ -100,9 +108,12 @@ async def _generate(self) -> None:
100108
self._mc.update(self._seqno)
101109
self._seqno += 1
102110

103-
iat = np.random.exponential(scale=1 / self._batch_rate)
111+
iat = self._compute_iat()
104112
await asyncio.sleep(iat)
105113

114+
def _compute_iat(self):
115+
return np.random.exponential(scale=1 / self._batch_rate)
116+
106117
async def get(self) -> list[Tensor | None]:
107118
"""Return one batch of requests.
108119
@@ -122,6 +133,74 @@ async def get(self) -> list[Tensor | None]:
122133
return batches
123134

124135

136+
class MultiRateExponentialGenerator(ExponentialGenerator):
    """Exponential generator with replay-dependent rate schedule.

    The sending rate starts at ``params.rate`` and switches to the rate of
    each schedule item once the current replay reaches that item's
    ``replay_index``.
    """

    def initialize(
        self,
        dataset,
        params,
        batch_size,
        mc,
    ) -> None:
        """Set up the rate schedule and start the request generation task.

        Args:
            dataset: forwarded to ``Generator.initialize``; its ``_replay``
                counter is read when computing inter-arrival times.
            params: generator parameters; ``rate``, ``schedule`` and
                ``replay`` are read from them. Must not be None.
            batch_size: the per-batch rate is the request rate divided by
                this value.
            mc: metrics collector forwarded to ``Generator.initialize``.
        """
        assert params is not None
        # intentionally bypassing super().initialize
        # for properly setting up queue and event and to avoid duplicating
        # asyncio task creation for _generate method
        Generator.initialize(self, dataset, params, batch_size, mc)

        self.range_list = self._prepare_schedule(
            self._params.rate, self._params.schedule, self._params.replay
        )

        # start in the first range; _compute_iat advances the index later
        self._range_index = 0
        rate = self.range_list[0][2]
        self._batch_rate = rate / self._batch_size

        self._queue = asyncio.Queue()
        self._gen_evt = asyncio.Event()
        # keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage collected
        self._generate_task = asyncio.create_task(self._generate())

        msg = f"generator initialized with rate={rate}"
        msg += f" replay rate update schedule={self._params.schedule}"
        logger.info(msg)

    def _prepare_schedule(
        self, base_rate: float, schedule: list[RateScheduleItem], max_replay: int
    ) -> list[tuple[int, int, float]]:
        """Convert replay-based schedule into continuous replay ranges.

        Args:
            base_rate: rate used from replay 0 until the first schedule item.
            schedule: rate changes, in any order.
            max_replay: upper bound (inclusive) of the last range.

        Returns:
            A non-empty list of ``(first_replay, last_replay, rate)`` tuples
            with inclusive bounds, ordered by replay.
        """
        # sorted() is stable, so for duplicate replay indexes the item listed
        # last in the configuration is the one that takes effect
        schedule_sorted = sorted(schedule, key=lambda s: s.replay_index)

        rate_schedule_ranges = []
        prev_replay = 0
        prev_rate = base_rate

        for item in schedule_sorted:
            # range [prev_replay, item.replay_index - 1] uses prev_rate;
            # skip it when empty (duplicate index or a change at replay 0)
            if item.replay_index > prev_replay:
                rate_schedule_ranges.append(
                    (prev_replay, item.replay_index - 1, prev_rate)
                )
            prev_replay = item.replay_index
            prev_rate = item.rate

        # last range goes until max_replay
        rate_schedule_ranges.append((prev_replay, max_replay, prev_rate))
        return rate_schedule_ranges

    def _compute_iat(self) -> float:
        """Return the next inter-arrival time, updating the rate if needed.

        Moves to the range containing the current replay (possibly skipping
        several ranges at once) and never moves past the last range.
        """
        # NOTE(review): assumes dataset._replay counts the replays still
        # remaining, so this difference grows over time - confirm
        current_replay = self._params.replay - self._dataset._replay

        last_index = len(self.range_list) - 1
        new_index = self._range_index
        while (
            new_index < last_index
            and current_replay > self.range_list[new_index][1]
        ):
            new_index += 1

        if new_index != self._range_index:
            self._range_index = new_index
            rate = self.range_list[new_index][2]
            self._batch_rate = rate / self._batch_size

            logger.info(f"sending rate updated to {rate}")

        return np.random.exponential(scale=1 / self._batch_rate)
202+
203+
125204
class GeneratorFactory:
126205
"""Request generator factory class."""
127206

@@ -131,6 +210,7 @@ def get(sort: ReqGenEnum) -> Generator:
131210
generators = {
132211
ReqGenEnum.DEFAULT: DefaultGenerator(),
133212
ReqGenEnum.EXP: ExponentialGenerator(),
213+
ReqGenEnum.MULTIRATE_EXP: MultiRateExponentialGenerator(),
134214
}
135215

136216
return generators[sort]

0 commit comments

Comments
 (0)