Skip to content

Commit f5f4fc6

Browse files
committed
updates
1 parent 9661ec8 commit f5f4fc6

File tree

4 files changed

+61
-44
lines changed

4 files changed

+61
-44
lines changed

src/inference_endpoint/endpoint_client/cpu_affinity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def wrapper(*args, **kwargs):
6060
# Default physical cores for LoadGen (main process):
6161
# - Session thread (scheduler, busy-wait timing)
6262
# - Event loop thread (uvloop, response handling)
63-
# TODO(vir): use +2 additional if available (since zmq-io-threads=4)
64-
DEFAULT_LOADGEN_CORES = 4
63+
DEFAULT_LOADGEN_CORES = 5
6564

6665

6766
# =============================================================================

src/inference_endpoint/endpoint_client/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class _SocketConfig:
5353
# client kernel sends probe, server's kernel ACKs - no application overhead
5454
#
5555
# TODO(vir): verify impact on failure-detection, we want to fail fast
56-
# detection time: KEEPIDLE + (KEEPCNT × KEEPINTVL) = 1 + 5×1 = 5s
56+
# detection time: KEEPIDLE + (KEEPCNT × KEEPINTVL) = 1 + 5×1 = 6s
5757
SO_KEEPALIVE: int = 1 # Enable keepalive at socket level
5858
TCP_KEEPIDLE: int = 1 # Probe after 1s idle
5959
TCP_KEEPCNT: int = 5 # 5 failed probes = dead

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ async def run(self) -> None:
201201

202202
# Create connection pool
203203
# Naively divide max connections among workers
204-
connections_per_worker = (
205-
self.http_config.max_connections // self.http_config.num_workers
204+
connections_per_worker = max(
205+
1, self.http_config.max_connections // self.http_config.num_workers
206206
)
207207
self._pool = ConnectionPool(
208208
host=self._host,

src/inference_endpoint/testing/variable_throughput_server.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
instantly for roofline testing), this server models realistic LLM inference:
88
99
* **Variable output lengths** — lognormal distribution, configurable mean + spread
10-
* **Per-request response rate** — lognormal distribution (responses/sec per request)
10+
* **Per-worker response rate** — token bucket rate limiter per worker process
1111
* **First-chunk latency (TTFT)** — lognormal delay before first data
12-
* **Per-chunk latency jitter** — lognormal inter-chunk delays in streaming mode
12+
* **Per-chunk latency jitter** — lognormal inter-token delays in streaming mode
1313
1414
Two mutually exclusive timing modes:
1515
1616
* **Response-rate mode** (``--response-rate-mean``): controls total response time
17-
per request. TPOT is derived from ``(1/rate - TTFT) / num_chunks``.
18-
* **Inter-chunk mode** (``--inter-chunk-latency``): controls per-chunk delay
19-
directly. Total response time = TTFT + num_chunks × TPOT.
17+
per request. TPOT is derived from ``(1/rate - TTFT) / (osl - 1)``.
18+
* **Inter-token mode** (``--inter-token-latency``): controls per-token delay
19+
(TPOT) directly. Actual inter-SSE-event delay = TPOT × stream_interval.
2020
2121
Usage::
2222
@@ -28,10 +28,10 @@
2828
--output-len-mean 1000 --output-len-spread 0.4 \\
2929
--response-rate-mean 10000 --response-rate-spread 2.0
3030
31-
# Streaming with inter-chunk latency (ms) + TTFT (s)
31+
# Streaming with inter-token latency (ms) + TTFT (s)
3232
python -m inference_endpoint.testing.variable_throughput_server --stream --stats \\
3333
--stream-interval 2 \\
34-
--inter-chunk-latency 20 --inter-chunk-spread 0.05 \\
34+
--inter-token-latency 20 --inter-token-spread 0.05 \\
3535
--first-chunk-latency 0.1 --first-chunk-spread 0.02
3636
3737
# Streaming with response-rate + TTFT
@@ -109,12 +109,14 @@ def _lognormal_params(mean: float, cv: float) -> tuple[float, float]:
109109

110110

111111
class _TokenBucket:
112-
"""Global rate limiter: enforces N responses/sec across all concurrent requests."""
112+
"""Per-worker rate limiter: enforces N responses/sec within a single worker."""
113113

114114
__slots__ = ("_interval", "_available_at")
115115

116116
def __init__(self, rate: float):
117-
self._interval = 1.0 / rate if rate > 0 else 0.0
117+
if rate <= 0:
118+
raise ValueError(f"rate must be > 0, got {rate}")
119+
self._interval = 1.0 / rate
118120
self._available_at = 0.0
119121

120122
async def acquire(self) -> None:
@@ -139,7 +141,7 @@ class VariableResponseProtocol(asyncio.Protocol):
139141
140142
Two timing modes:
141143
- **rate**: Per-worker token bucket controls that worker's throughput (responses/sec).
142-
- **icl**: Per-request inter-chunk latency, no global rate limit.
144+
- **icl**: Per-request inter-token latency, no global rate limit.
143145
144146
Because ``data_received`` is synchronous, async work is dispatched via
145147
``loop.create_task``.
@@ -171,7 +173,7 @@ class VariableResponseProtocol(asyncio.Protocol):
171173
"_rate_mu",
172174
"_rate_sigma",
173175
"_rate_mean",
174-
# Inter-chunk-latency params (mode="icl")
176+
# Inter-token-latency params (mode="icl")
175177
"_icl_mu",
176178
"_icl_sigma",
177179
"_icl_mean",
@@ -277,7 +279,6 @@ async def _handle_request(self, osl: int):
277279

278280
try:
279281
rng = self._rng
280-
num_chunks = math.ceil(osl / self._stream_interval) if self._stream else 1
281282

282283
# Rate mode: global token bucket controls overall throughput
283284
if self._mode == "rate":
@@ -292,8 +293,10 @@ async def _handle_request(self, osl: int):
292293
rng, self._rate_mu, self._rate_sigma, self._rate_mean
293294
)
294295
total_time = 1.0 / rate if rate > 0 else 0.0
295-
if self._stream and num_chunks > 1:
296-
tpot = max(0.0, (total_time - ttft) / num_chunks)
296+
if self._stream and osl > 1:
297+
# Derive per-token TPOT: remaining time spread across output tokens.
298+
# Streaming loop scales by tokens_per_chunk automatically.
299+
tpot = max(0.0, (total_time - ttft) / (osl - 1))
297300
else:
298301
# Non-streaming or single chunk: total_time is the full delay
299302
ttft = total_time
@@ -332,13 +335,18 @@ async def _handle_non_streaming(self, transport, osl: int, delay: float):
332335
_counter_add(_byte_counter, len(response))
333336

334337
async def _handle_streaming(self, transport, osl: int, ttft: float, tpot: float):
335-
"""Stream chunks with TTFT + per-chunk TPOT delays."""
338+
"""Stream chunks with TTFT + per-token TPOT delays.
339+
340+
tpot is the simulated per-token generation time (seconds).
341+
stream_interval is chars per SSE event (≈ tokens).
342+
Actual inter-event delay = tpot × stream_interval.
343+
"""
336344
interval = self._stream_interval
337345
num_events = math.ceil(osl / interval)
338346
created = int(time.time())
339347
model = self._model
340348

341-
# Pre-compile all chunk bytes upfront to keep the hot loop free of encoding.
349+
# Pre-compile all chunk bytes upfront.
342350
chunks: list[bytes] = []
343351
chars_left = osl
344352
for _ in range(num_events):
@@ -377,9 +385,9 @@ async def _handle_streaming(self, transport, osl: int, ttft: float, tpot: float)
377385
if transport.is_closing():
378386
return
379387

380-
# TPOT delay between chunks (skip first — TTFT already applied)
388+
# Delay = tpot × chars_per_event (chars ≈ tokens). Skip first — TTFT already applied.
381389
if i > 0 and tpot > 0:
382-
target += tpot
390+
target += tpot * interval
383391
wait = target - loop.time()
384392
if wait > 0:
385393
await asyncio.sleep(wait)
@@ -505,7 +513,7 @@ class VariableResponseServer:
505513
Timing is per-request with two mutually exclusive modes:
506514
507515
* **Response-rate mode**: each request samples its own rate, TPOT is derived.
508-
* **Inter-chunk mode**: each request samples its own TPOT directly.
516+
* **Inter-token mode**: each request samples its own TPOT directly.
509517
510518
Both modes support first-chunk latency (TTFT) with optional spread.
511519
@@ -518,8 +526,8 @@ class VariableResponseServer:
518526
output_len_max: Maximum output sequence length (chars). None = 2 * mean.
519527
response_rate_mean: Per-request response rate mean (responses/sec). 0 = no rate mode.
520528
response_rate_spread: CoV for per-request response rate.
521-
inter_chunk_latency: Per-chunk delay mean in milliseconds. 0 = no ICL mode.
522-
inter_chunk_spread: CoV for per-chunk delay.
529+
inter_token_latency: Per-token delay (TPOT) mean in milliseconds. 0 = no ICL mode.
530+
inter_token_spread: CoV for per-token delay (TPOT).
523531
first_chunk_latency: Mean TTFT in seconds. 0 = no TTFT delay.
524532
first_chunk_spread: CoV for TTFT.
525533
stream: SSE streaming mode.
@@ -541,8 +549,8 @@ def __init__(
541549
output_len_max: int | None = None,
542550
response_rate_mean: float = 0.0,
543551
response_rate_spread: float = 0.0,
544-
inter_chunk_latency: float = 0.0,
545-
inter_chunk_spread: float = 0.0,
552+
inter_token_latency: float = 0.0,
553+
inter_token_spread: float = 0.0,
546554
first_chunk_latency: float = 0.0,
547555
first_chunk_spread: float = 0.2,
548556
stream: bool = False,
@@ -554,10 +562,20 @@ def __init__(
554562
quiet: bool = False,
555563
):
556564
# Validate mutual exclusivity
557-
if response_rate_mean > 0 and inter_chunk_latency > 0:
565+
if response_rate_mean > 0 and inter_token_latency > 0:
566+
raise ValueError(
567+
"response_rate_mean and inter_token_latency are mutually exclusive. "
568+
"Use response-rate mode OR inter-token-latency mode, not both."
569+
)
570+
if num_workers <= 0:
571+
raise ValueError(f"num_workers must be > 0, got {num_workers}")
572+
if stream and stream_interval <= 0:
573+
raise ValueError(
574+
f"stream_interval must be > 0 in streaming mode, got {stream_interval}"
575+
)
576+
if response_rate_mean < 0:
558577
raise ValueError(
559-
"response_rate_mean and inter_chunk_latency are mutually exclusive. "
560-
"Use response-rate mode OR inter-chunk-latency mode, not both."
578+
f"response_rate_mean must be >= 0, got {response_rate_mean}"
561579
)
562580

563581
self.host, self.port, self.num_workers = host, port, num_workers
@@ -585,7 +603,7 @@ def __init__(
585603
# Determine timing mode
586604
if response_rate_mean > 0:
587605
self._mode = "rate"
588-
elif inter_chunk_latency > 0:
606+
elif inter_token_latency > 0:
589607
self._mode = "icl"
590608
else:
591609
# Neither specified — no delays (instant response)
@@ -611,10 +629,10 @@ def __init__(
611629
self._rate_mu, self._rate_sigma = 0.0, 0.0
612630
self._rate_mean = response_rate_mean
613631

614-
# Inter-chunk-latency params (convert ms → seconds for asyncio.sleep)
615-
icl_s = inter_chunk_latency / 1000.0
632+
# Inter-token-latency params (convert ms → seconds for asyncio.sleep)
633+
icl_s = inter_token_latency / 1000.0
616634
if icl_s > 0:
617-
self._icl_mu, self._icl_sigma = _lognormal_params(icl_s, inter_chunk_spread)
635+
self._icl_mu, self._icl_sigma = _lognormal_params(icl_s, inter_token_spread)
618636
else:
619637
self._icl_mu, self._icl_sigma = 0.0, 0.0
620638
self._icl_mean = icl_s
@@ -816,7 +834,7 @@ def main():
816834
"--output-len-max",
817835
type=int,
818836
default=None,
819-
help="Maximum output sequence length in chars (default: 8 * output-len-mean)",
837+
help="Maximum output sequence length in chars (default: 2 * output-len-mean)",
820838
)
821839

822840
# Timing mode A: response-rate
@@ -825,7 +843,7 @@ def main():
825843
type=float,
826844
default=0.0,
827845
help="Per-request response rate mean in resp/sec (default: 0 = disabled). "
828-
"Mutually exclusive with --inter-chunk-latency.",
846+
"Mutually exclusive with --inter-token-latency.",
829847
)
830848
parser.add_argument(
831849
"--response-rate-spread",
@@ -834,17 +852,17 @@ def main():
834852
help="CoV for per-request response rate (default: 0.0 = deterministic)",
835853
)
836854

837-
# Timing mode B: inter-chunk-latency
855+
# Timing mode B: inter-token-latency
838856
parser.add_argument(
839-
"--inter-chunk-latency",
857+
"--inter-token-latency",
840858
type=float,
841859
default=0.0,
842-
help="Mean per-chunk delay in milliseconds (default: 0 = disabled). "
843-
"E.g. 20 = 20ms per chunk ≈ 50 tokens/sec. "
860+
help="Per-token generation time (TPOT) in milliseconds (default: 0 = disabled). "
861+
"E.g. 20 = 20ms/token. Actual inter-SSE-event delay = TPOT × stream_interval. "
844862
"Mutually exclusive with --response-rate-mean.",
845863
)
846864
parser.add_argument(
847-
"--inter-chunk-spread",
865+
"--inter-token-spread",
848866
type=float,
849867
default=0.0,
850868
help="CoV for per-token delay (TPOT) (default: 0.0 = deterministic)",
@@ -896,8 +914,8 @@ def main():
896914
output_len_max=args.output_len_max,
897915
response_rate_mean=args.response_rate_mean,
898916
response_rate_spread=args.response_rate_spread,
899-
inter_chunk_latency=args.inter_chunk_latency,
900-
inter_chunk_spread=args.inter_chunk_spread,
917+
inter_token_latency=args.inter_token_latency,
918+
inter_token_spread=args.inter_token_spread,
901919
first_chunk_latency=args.first_chunk_latency,
902920
first_chunk_spread=args.first_chunk_spread,
903921
stream=args.stream,

0 commit comments

Comments
 (0)