Skip to content

Commit ff40629

Browse files
authored
Merge pull request #4 from ServerlessLLM/fy/flexible_rps
Fy/flexible rps
2 parents b1a1526 + 5a5f019 commit ff40629

File tree

5 files changed

+87
-122
lines changed

5 files changed

+87
-122
lines changed

tests/test_cli.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ def test_create_trace_generator_synthetic(self):
2121
self.assertIsInstance(generator, SyntheticTraceGenerator)
2222
self.assertEqual(warning, "")
2323

24+
def test_create_trace_generator_synthetic_float_rps(self):
25+
"""Test creating synthetic trace generator with float RPS."""
26+
generator, warning = create_trace_generator("uniform", 2.5, 60)
27+
28+
self.assertIsInstance(generator, SyntheticTraceGenerator)
29+
self.assertEqual(warning, "")
30+
self.assertEqual(generator.rps, 2.5)
31+
2432
def test_create_trace_generator_azure(self):
2533
"""Test creating Azure trace generator."""
2634
generator, warning = create_trace_generator("azure_code", 10, 60)
@@ -29,6 +37,14 @@ def test_create_trace_generator_azure(self):
2937
self.assertIn("RPS parameter (10) is ignored", warning)
3038
self.assertIn("Duration parameter (60) is ignored", warning)
3139

40+
def test_create_trace_generator_azure_float_rps(self):
41+
"""Test creating Azure trace generator with float RPS."""
42+
generator, warning = create_trace_generator("azure_code", 0.5, 60)
43+
44+
self.assertIsInstance(generator, AzureTraceGenerator)
45+
self.assertIn("RPS parameter (0.5) is ignored", warning)
46+
self.assertIn("Duration parameter (60) is ignored", warning)
47+
3248
def test_create_trace_generator_invalid(self):
3349
"""Test creating generator with invalid pattern."""
3450
with self.assertRaises(ValueError):
@@ -74,6 +90,29 @@ def test_cli_with_options(self, mock_run_load_test):
7490
self.assertEqual(result.exit_code, 0)
7591
mock_run_load_test.assert_called_once()
7692

93+
@patch("tracestorm.cli.run_load_test")
94+
def test_cli_with_float_rps(self, mock_run_load_test):
95+
"""Test CLI with float RPS value."""
96+
mock_analyzer = MagicMock()
97+
mock_run_load_test.return_value = ([], mock_analyzer)
98+
99+
result = self.runner.invoke(
100+
main,
101+
[
102+
"--model",
103+
"gpt-3.5-turbo",
104+
"--rps",
105+
"0.5",
106+
"--pattern",
107+
"uniform",
108+
"--duration",
109+
"30",
110+
],
111+
)
112+
113+
self.assertEqual(result.exit_code, 0)
114+
mock_run_load_test.assert_called_once()
115+
77116
def test_cli_invalid_pattern(self):
78117
"""Test CLI with invalid pattern."""
79118
result = self.runner.invoke(

tests/test_trace_generator.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,43 @@ def test_uniform_distribution(self):
2121
expected = [0, 500, 1000, 1500, 2000, 2500]
2222
self.assertEqual(result, expected)
2323

24+
def test_uniform_distribution_float_rps(self):
25+
"""Test uniform distribution pattern with float RPS value."""
26+
generator = SyntheticTraceGenerator(
27+
rps=1.5, pattern="uniform", duration=4
28+
)
29+
# Let's get the actual result and use direct value comparison
30+
result = generator.generate()
31+
# 1.5 RPS for 4 seconds = 6 requests
32+
self.assertEqual(len(result), 6)
33+
# First and last timestamps should be consistent
34+
self.assertEqual(result[0], 0)
35+
self.assertTrue(result[-1] < 4000) # Should be less than duration in ms
36+
2437
def test_invalid_rps(self):
2538
"""Test invalid RPS value."""
2639
with self.assertRaises(ValueError) as context:
2740
SyntheticTraceGenerator(rps=-1, pattern="uniform", duration=10)
2841
self.assertEqual(
29-
str(context.exception), "rps must be a non-negative integer"
42+
str(context.exception), "rps must be a non-negative number"
43+
)
44+
45+
def test_invalid_rps_float(self):
46+
"""Test invalid RPS float value."""
47+
with self.assertRaises(ValueError) as context:
48+
SyntheticTraceGenerator(rps=-0.5, pattern="uniform", duration=10)
49+
self.assertEqual(
50+
str(context.exception), "rps must be a non-negative number"
51+
)
52+
53+
def test_valid_float_rps(self):
54+
"""Test valid float RPS value."""
55+
generator = SyntheticTraceGenerator(
56+
rps=0.5, pattern="uniform", duration=10
3057
)
58+
result = generator.generate()
59+
# 0.5 RPS for 10 seconds = 5 total requests
60+
self.assertEqual(len(result), 5)
3161

3262
def test_invalid_duration(self):
3363
"""Test invalid duration value."""

tracestorm/cli.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Optional, Tuple
2+
from typing import Optional, Tuple, Union
33

44
import click
55

@@ -21,7 +21,10 @@
2121

2222

2323
def create_trace_generator(
24-
pattern: str, rps: int, duration: int, seed: Optional[int] = None
24+
pattern: str,
25+
rps: Union[int, float],
26+
duration: int,
27+
seed: Optional[int] = None,
2528
) -> Tuple[TraceGenerator, str]:
2629
"""
2730
Create appropriate trace generator based on pattern and validate parameters.
@@ -72,8 +75,8 @@ def create_trace_generator(
7275
@click.option("--model", required=True, help="Model name")
7376
@click.option(
7477
"--rps",
75-
type=int,
76-
default=1,
78+
type=float,
79+
default=1.0,
7780
help="Requests per second (only used with synthetic patterns)",
7881
)
7982
@click.option(

tracestorm/main.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

tracestorm/trace_generator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import tempfile
33
from abc import ABC, abstractmethod
4-
from typing import List, Optional
4+
from typing import List, Optional, Union
55

66
import numpy as np
77
import pandas as pd
@@ -34,19 +34,23 @@ class SyntheticTraceGenerator(TraceGenerator):
3434
"""Generate synthetic traces based on patterns."""
3535

3636
def __init__(
37-
self, rps: int, pattern: str, duration: int, seed: Optional[int] = None
37+
self,
38+
rps: Union[int, float],
39+
pattern: str,
40+
duration: int,
41+
seed: Optional[int] = None,
3842
):
3943
"""
4044
Initialize synthetic trace generator.
4145
4246
Args:
43-
rps (int): Requests per second. Must be non-negative.
47+
rps (Union[int, float]): Requests per second. Must be non-negative.
4448
pattern (str): Distribution pattern ('uniform', 'random', 'poisson', etc.).
4549
duration (int): Total duration in seconds. Must be non-negative.
4650
seed (int): Seed for reproducibility of 'poisson' and 'random' patterns
4751
"""
48-
if not isinstance(rps, int) or rps < 0:
49-
raise ValueError("rps must be a non-negative integer")
52+
if not isinstance(rps, (int, float)) or rps < 0:
53+
raise ValueError("rps must be a non-negative number")
5054
if not isinstance(duration, int) or duration < 0:
5155
raise ValueError("duration must be a non-negative integer")
5256

@@ -57,7 +61,7 @@ def __init__(
5761
np.random.seed(seed)
5862

5963
def generate(self) -> List[int]:
60-
total_requests = self.rps * self.duration
64+
total_requests = int(round(self.rps * self.duration))
6165
total_duration_ms = self.duration * 1000
6266
timestamps = []
6367

0 commit comments

Comments
 (0)