From 4e6baf297ac4e88b416751b9eebc55d98a068f94 Mon Sep 17 00:00:00 2001 From: Yao Fu Date: Fri, 25 Apr 2025 13:46:37 +0100 Subject: [PATCH 1/3] refactor: remove main.py --- tracestorm/main.py | 111 --------------------------------------------- 1 file changed, 111 deletions(-) delete mode 100644 tracestorm/main.py diff --git a/tracestorm/main.py b/tracestorm/main.py deleted file mode 100644 index c6d4ef2..0000000 --- a/tracestorm/main.py +++ /dev/null @@ -1,111 +0,0 @@ -import argparse -import multiprocessing -import os - -from tracestorm.logger import init_logger -from tracestorm.request_generator import generate_request -from tracestorm.result_analyzer import ResultAnalyzer -from tracestorm.trace_generator import generate_trace -from tracestorm.trace_player import play -from tracestorm.utils import round_robin_shard - -logger = init_logger(__name__) - - -def get_args(): - parser = argparse.ArgumentParser( - description="Run a replay of OpenAI requests." - ) - parser.add_argument("--model", required=True, help="Model name") - parser.add_argument( - "--rps", type=int, default=1, help="Requests per second" - ) - parser.add_argument( - "--pattern", default="uniform", help="Pattern for generating trace" - ) - parser.add_argument( - "--duration", type=int, default=10, help="Duration in seconds" - ) - parser.add_argument( - "--subprocesses", type=int, default=1, help="Number of subprocesses" - ) - parser.add_argument( - "--base-url", - default=os.environ.get("OPENAI_BASE_URL", "http://localhost:8000/v1"), - help="OpenAI Base URL", - ) - parser.add_argument( - "--api-key", - default=os.environ.get("OPENAI_API_KEY", "none"), - help="OpenAI API Key", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - raw_trace = generate_trace(args.rps, args.pattern, args.duration) - total_requests = len(raw_trace) - logger.debug(f"Raw trace: {raw_trace}") - - requests = generate_request(args.model, total_requests) - logger.debug(f"Requests: {requests}") - - ipc_queue = multiprocessing.Queue() - processes = [] - - if total_requests == 0: - logger.warning("No requests to process. Trace is empty.") - return - - # Launch subprocesses - for i, (partial_trace, partial_requests) in enumerate( - round_robin_shard(raw_trace, requests, args.subprocesses), start=1 - ): - p = multiprocessing.Process( - target=play, - args=( - f"TracePlayer-{i}", - partial_trace, - partial_requests, - args.base_url, - args.api_key, - ipc_queue, - ), - ) - p.start() - processes.append(p) - - results_collected = 0 - aggregated_results = [] - while results_collected < total_requests: - try: - name, timestamp, resp = ipc_queue.get(timeout=30) - results_collected += 1 - logger.info( - f"Received result from {name} for timestamp {timestamp}: {resp['token_count']} tokens" - ) - aggregated_results.append((name, timestamp, resp)) - except Exception as e: - logger.error( - f"Timeout or error reading from IPC queue: {e}", exc_info=True - ) - break - - for p in processes: - p.join() - - logger.info("All subprocesses have finished.") - - logger.debug(f"Aggregated results: {aggregated_results}") - - result_analyzer = ResultAnalyzer() - result_analyzer.store_raw_results(aggregated_results) - print(result_analyzer) - result_analyzer.plot_cdf() - - -if __name__ == "__main__": - main() From ca73cd870cfe30e3813c1045db2d52bae600d469 Mon Sep 17 00:00:00 2001 From: Yao Fu Date: Fri, 25 Apr 2025 14:04:12 +0100 Subject: [PATCH 2/3] feat(trace_generator): support floating-point RPS values Add support for floating-point requests per second (RPS) values in the trace generator, making the load testing more flexible for low-throughput scenarios. Changes include: - Update type hints from `int` to `Union[int, float]` for RPS parameter - Cast total_requests to int when calculating from float RPS - Change CLI option type from int to float with default value 1.0 - Update validation error message to reflect new type requirement - Add comprehensive test cases for float RPS values --- tests/test_cli.py | 39 +++++++++++++++++++++++++++++++++++ tests/test_trace_generator.py | 32 +++++++++++++++++++++++++++- tracestorm/cli.py | 11 ++++++---- tracestorm/trace_generator.py | 16 ++++++++------ 4 files changed, 87 insertions(+), 11 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8ed12a2..30fb5dd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -21,6 +21,14 @@ def test_create_trace_generator_synthetic(self): self.assertIsInstance(generator, SyntheticTraceGenerator) self.assertEqual(warning, "") + def test_create_trace_generator_synthetic_float_rps(self): + """Test creating synthetic trace generator with float RPS.""" + generator, warning = create_trace_generator("uniform", 2.5, 60) + + self.assertIsInstance(generator, SyntheticTraceGenerator) + self.assertEqual(warning, "") + self.assertEqual(generator.rps, 2.5) + def test_create_trace_generator_azure(self): """Test creating Azure trace generator.""" generator, warning = create_trace_generator("azure_code", 10, 60) @@ -29,6 +37,14 @@ def test_create_trace_generator_azure(self): self.assertIn("RPS parameter (10) is ignored", warning) self.assertIn("Duration parameter (60) is ignored", warning) + def test_create_trace_generator_azure_float_rps(self): + """Test creating Azure trace generator with float RPS.""" + generator, warning = create_trace_generator("azure_code", 0.5, 60) + + self.assertIsInstance(generator, AzureTraceGenerator) + self.assertIn("RPS parameter (0.5) is ignored", warning) + self.assertIn("Duration parameter (60) is ignored", warning) + def test_create_trace_generator_invalid(self): """Test creating generator with invalid pattern.""" with self.assertRaises(ValueError): @@ -74,6 +90,29 @@ def test_cli_with_options(self, mock_run_load_test): self.assertEqual(result.exit_code, 0) mock_run_load_test.assert_called_once() + @patch("tracestorm.cli.run_load_test") + def test_cli_with_float_rps(self, mock_run_load_test): + """Test CLI with float RPS value.""" + mock_analyzer = MagicMock() + mock_run_load_test.return_value = ([], mock_analyzer) + + result = self.runner.invoke( + main, + [ + "--model", + "gpt-3.5-turbo", + "--rps", + "0.5", + "--pattern", + "uniform", + "--duration", + "30", + ], + ) + + self.assertEqual(result.exit_code, 0) + mock_run_load_test.assert_called_once() + def test_cli_invalid_pattern(self): """Test CLI with invalid pattern.""" result = self.runner.invoke( diff --git a/tests/test_trace_generator.py b/tests/test_trace_generator.py index 707d89c..4bdf4e0 100644 --- a/tests/test_trace_generator.py +++ b/tests/test_trace_generator.py @@ -21,13 +21,43 @@ def test_uniform_distribution(self): expected = [0, 500, 1000, 1500, 2000, 2500] self.assertEqual(result, expected) + def test_uniform_distribution_float_rps(self): + """Test uniform distribution pattern with float RPS value.""" + generator = SyntheticTraceGenerator( + rps=1.5, pattern="uniform", duration=4 + ) + # Let's get the actual result and use direct value comparison + result = generator.generate() + # 1.5 RPS for 4 seconds = 6 requests + self.assertEqual(len(result), 6) + # First and last timestamps should be consistent + self.assertEqual(result[0], 0) + self.assertTrue(result[-1] < 4000) # Should be less than duration in ms + def test_invalid_rps(self): """Test invalid RPS value.""" with self.assertRaises(ValueError) as context: SyntheticTraceGenerator(rps=-1, pattern="uniform", duration=10) self.assertEqual( - str(context.exception), "rps must be a non-negative integer" + str(context.exception), "rps must be a non-negative number" + ) + + def test_invalid_rps_float(self): + """Test invalid RPS float value.""" + with self.assertRaises(ValueError) as context: + SyntheticTraceGenerator(rps=-0.5, pattern="uniform", duration=10) + self.assertEqual( + str(context.exception), "rps must be a non-negative number" + ) + + def test_valid_float_rps(self): + """Test valid float RPS value.""" + generator = SyntheticTraceGenerator( + rps=0.5, pattern="uniform", duration=10 ) + result = generator.generate() + # 0.5 RPS for 10 seconds = 5 total requests + self.assertEqual(len(result), 5) def test_invalid_duration(self): """Test invalid duration value.""" diff --git a/tracestorm/cli.py b/tracestorm/cli.py index 2647858..0e6c9ce 100644 --- a/tracestorm/cli.py +++ b/tracestorm/cli.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import click @@ -21,7 +21,10 @@ def create_trace_generator( - pattern: str, rps: int, duration: int, seed: Optional[int] = None + pattern: str, + rps: Union[int, float], + duration: int, + seed: Optional[int] = None, ) -> Tuple[TraceGenerator, str]: """ Create appropriate trace generator based on pattern and validate parameters. @@ -72,8 +75,8 @@ def create_trace_generator( @click.option("--model", required=True, help="Model name") @click.option( "--rps", - type=int, - default=1, + type=float, + default=1.0, help="Requests per second (only used with synthetic patterns)", ) @click.option( diff --git a/tracestorm/trace_generator.py b/tracestorm/trace_generator.py index 36b98b5..3974fd8 100644 --- a/tracestorm/trace_generator.py +++ b/tracestorm/trace_generator.py @@ -1,7 +1,7 @@ import os import tempfile from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Optional, Union import numpy as np import pandas as pd @@ -34,19 +34,23 @@ class SyntheticTraceGenerator(TraceGenerator): """Generate synthetic traces based on patterns.""" def __init__( - self, rps: int, pattern: str, duration: int, seed: Optional[int] = None + self, + rps: Union[int, float], + pattern: str, + duration: int, + seed: Optional[int] = None, ): """ Initialize synthetic trace generator. Args: - rps (int): Requests per second. Must be non-negative. + rps (Union[int, float]): Requests per second. Must be non-negative. pattern (str): Distribution pattern ('uniform', 'random', 'poisson', etc.). duration (int): Total duration in seconds. Must be non-negative. seed (int): Seed for reproducibility of 'poisson' and 'random' patterns """ - if not isinstance(rps, int) or rps < 0: - raise ValueError("rps must be a non-negative integer") + if not isinstance(rps, (int, float)) or rps < 0: + raise ValueError("rps must be a non-negative number") if not isinstance(duration, int) or duration < 0: raise ValueError("duration must be a non-negative integer") @@ -57,7 +61,7 @@ def __init__( np.random.seed(seed) def generate(self) -> List[int]: - total_requests = self.rps * self.duration + total_requests = int(self.rps * self.duration) total_duration_ms = self.duration * 1000 timestamps = [] From 5a5f0191458cb1a45bd172d2f8fb0ed30729bbc9 Mon Sep 17 00:00:00 2001 From: Yao Fu Date: Fri, 25 Apr 2025 14:12:44 +0100 Subject: [PATCH 3/3] fix: round before int --- tracestorm/trace_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracestorm/trace_generator.py b/tracestorm/trace_generator.py index 3974fd8..4b41968 100644 --- a/tracestorm/trace_generator.py +++ b/tracestorm/trace_generator.py @@ -61,7 +61,7 @@ def __init__( np.random.seed(seed) def generate(self) -> List[int]: - total_requests = int(self.rps * self.duration) + total_requests = int(round(self.rps * self.duration)) total_duration_ms = self.duration * 1000 timestamps = []