Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ def test_create_trace_generator_synthetic(self):
self.assertIsInstance(generator, SyntheticTraceGenerator)
self.assertEqual(warning, "")

def test_create_trace_generator_synthetic_float_rps(self):
"""Test creating synthetic trace generator with float RPS."""
generator, warning = create_trace_generator("uniform", 2.5, 60)

self.assertIsInstance(generator, SyntheticTraceGenerator)
self.assertEqual(warning, "")
self.assertEqual(generator.rps, 2.5)

def test_create_trace_generator_azure(self):
"""Test creating Azure trace generator."""
generator, warning = create_trace_generator("azure_code", 10, 60)
Expand All @@ -29,6 +37,14 @@ def test_create_trace_generator_azure(self):
self.assertIn("RPS parameter (10) is ignored", warning)
self.assertIn("Duration parameter (60) is ignored", warning)

def test_create_trace_generator_azure_float_rps(self):
"""Test creating Azure trace generator with float RPS."""
generator, warning = create_trace_generator("azure_code", 0.5, 60)

self.assertIsInstance(generator, AzureTraceGenerator)
self.assertIn("RPS parameter (0.5) is ignored", warning)
self.assertIn("Duration parameter (60) is ignored", warning)

def test_create_trace_generator_invalid(self):
"""Test creating generator with invalid pattern."""
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -74,6 +90,29 @@ def test_cli_with_options(self, mock_run_load_test):
self.assertEqual(result.exit_code, 0)
mock_run_load_test.assert_called_once()

@patch("tracestorm.cli.run_load_test")
def test_cli_with_float_rps(self, mock_run_load_test):
"""Test CLI with float RPS value."""
mock_analyzer = MagicMock()
mock_run_load_test.return_value = ([], mock_analyzer)

result = self.runner.invoke(
main,
[
"--model",
"gpt-3.5-turbo",
"--rps",
"0.5",
"--pattern",
"uniform",
"--duration",
"30",
],
)

self.assertEqual(result.exit_code, 0)
mock_run_load_test.assert_called_once()

def test_cli_invalid_pattern(self):
"""Test CLI with invalid pattern."""
result = self.runner.invoke(
Expand Down
32 changes: 31 additions & 1 deletion tests/test_trace_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,43 @@ def test_uniform_distribution(self):
expected = [0, 500, 1000, 1500, 2000, 2500]
self.assertEqual(result, expected)

def test_uniform_distribution_float_rps(self):
"""Test uniform distribution pattern with float RPS value."""
generator = SyntheticTraceGenerator(
rps=1.5, pattern="uniform", duration=4
)
# Let's get the actual result and use direct value comparison
result = generator.generate()
# 1.5 RPS for 4 seconds = 6 requests
self.assertEqual(len(result), 6)
# First and last timestamps should be consistent
self.assertEqual(result[0], 0)
self.assertTrue(result[-1] < 4000) # Should be less than duration in ms

def test_invalid_rps(self):
"""Test invalid RPS value."""
with self.assertRaises(ValueError) as context:
SyntheticTraceGenerator(rps=-1, pattern="uniform", duration=10)
self.assertEqual(
str(context.exception), "rps must be a non-negative integer"
str(context.exception), "rps must be a non-negative number"
)

def test_invalid_rps_float(self):
"""Test invalid RPS float value."""
with self.assertRaises(ValueError) as context:
SyntheticTraceGenerator(rps=-0.5, pattern="uniform", duration=10)
self.assertEqual(
str(context.exception), "rps must be a non-negative number"
)

def test_valid_float_rps(self):
"""Test valid float RPS value."""
generator = SyntheticTraceGenerator(
rps=0.5, pattern="uniform", duration=10
)
result = generator.generate()
# 0.5 RPS for 10 seconds = 5 total requests
self.assertEqual(len(result), 5)

def test_invalid_duration(self):
"""Test invalid duration value."""
Expand Down
11 changes: 7 additions & 4 deletions tracestorm/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import click

Expand All @@ -21,7 +21,10 @@


def create_trace_generator(
pattern: str, rps: int, duration: int, seed: Optional[int] = None
pattern: str,
rps: Union[int, float],
duration: int,
seed: Optional[int] = None,
) -> Tuple[TraceGenerator, str]:
"""
Create appropriate trace generator based on pattern and validate parameters.
Expand Down Expand Up @@ -72,8 +75,8 @@ def create_trace_generator(
@click.option("--model", required=True, help="Model name")
@click.option(
"--rps",
type=int,
default=1,
type=float,
default=1.0,
help="Requests per second (only used with synthetic patterns)",
)
@click.option(
Expand Down
111 changes: 0 additions & 111 deletions tracestorm/main.py

This file was deleted.

16 changes: 10 additions & 6 deletions tracestorm/trace_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,19 +34,23 @@ class SyntheticTraceGenerator(TraceGenerator):
"""Generate synthetic traces based on patterns."""

def __init__(
self, rps: int, pattern: str, duration: int, seed: Optional[int] = None
self,
rps: Union[int, float],
pattern: str,
duration: int,
seed: Optional[int] = None,
):
"""
Initialize synthetic trace generator.

Args:
rps (int): Requests per second. Must be non-negative.
rps (Union[int, float]): Requests per second. Must be non-negative.
pattern (str): Distribution pattern ('uniform', 'random', 'poisson', etc.).
duration (int): Total duration in seconds. Must be non-negative.
seed (int): Seed for reproducibility of 'poisson' and 'random' patterns
"""
if not isinstance(rps, int) or rps < 0:
raise ValueError("rps must be a non-negative integer")
if not isinstance(rps, (int, float)) or rps < 0:
raise ValueError("rps must be a non-negative number")
if not isinstance(duration, int) or duration < 0:
raise ValueError("duration must be a non-negative integer")

Expand All @@ -57,7 +61,7 @@ def __init__(
np.random.seed(seed)

def generate(self) -> List[int]:
total_requests = self.rps * self.duration
total_requests = int(round(self.rps * self.duration))
total_duration_ms = self.duration * 1000
timestamps = []

Expand Down