diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a924048..00fc9d3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,17 @@ repos: args: [--fix, --exit-non-zero-on-fix] - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + args: [--config-file=./pyproject.toml, src, tests] + pass_filenames: false + verbose: true + additional_dependencies: + - types-PyYAML==6.0.12 + - types-requests>=2.32.4 + - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 hooks: diff --git a/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml b/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml index de0dd14f..146ee2e8 100644 --- a/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml +++ b/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml @@ -1,6 +1,7 @@ name: "gpt-oss-120b-benchmark" version: "1.0" type: "online" +timeout: 60000 model_params: name: "openai/gpt-oss-120b" @@ -38,8 +39,8 @@ datasets: num_repeats: 5 settings: runtime: - min_duration_ms: 300 - max_duration_ms: 6000 + min_duration_ms: 3000 + max_duration_ms: 60000 scheduler_random_seed: 42 dataloader_random_seed: 42 diff --git a/pyproject.toml b/pyproject.toml index c315e685..675fcf1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,3 +199,17 @@ exclude_lines = [ "class .*\\\\bProtocol\\):", "@(abc\\\\.)?abstractmethod", ] + +[tool.mypy] +ignore_missing_imports = true +exclude = [ + "openai_types_gen\\.py$", + "metrics/reporter\\.py$", +] + +[[tool.mypy.overrides]] +module = [ + "inference_endpoint.openai.openai_types_gen", + "inference_endpoint.metrics.reporter", +] +ignore_errors = true diff --git a/src/inference_endpoint/commands/benchmark.py b/src/inference_endpoint/commands/benchmark.py index 0056f53d..15f057db 100644 --- a/src/inference_endpoint/commands/benchmark.py +++ b/src/inference_endpoint/commands/benchmark.py @@ -28,6 +28,7 @@ import uuid from 
dataclasses import dataclass from pathlib import Path +from typing import Any from urllib.parse import urljoin from tqdm import tqdm @@ -40,7 +41,6 @@ APIType, BenchmarkConfig, ClientSettings, - Dataset, DatasetType, EndpointConfig, LoadPattern, @@ -51,11 +51,16 @@ RuntimeConfig, Settings, StreamingMode, + SystemDefaults, TestMode, TestType, ) +from inference_endpoint.config.schema import ( + Dataset as DatasetConfig, +) from inference_endpoint.config.yaml_loader import ConfigError, ConfigLoader from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.dataset import Dataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.cpu_affinity import pin_loadgen @@ -134,7 +139,7 @@ def on_complete_hook(self, result: QueryResult): if self.pbar: self.pbar.set_postfix(refresh=True, errors=len(self.errors)) elif self.collect_responses: - self.responses[result.id] = result.response_output + self.responses[result.id] = result.get_response_output_string() if self.pbar: self.pbar.update(1) @@ -147,7 +152,7 @@ class AccuracyConfiguration: dataset_name: str dataset: Dataset report_dir: os.PathLike - ground_truth_column: str + ground_truth_column: str | None num_repeats: int @@ -227,9 +232,7 @@ async def run_benchmark_command(args: argparse.Namespace) -> None: elif benchmark_mode_str in ("offline", "online"): # ===== CLI MODE - Build config from CLI params ===== benchmark_mode = TestType(benchmark_mode_str) # TestType values are lowercase - effective_config: BenchmarkConfig = _build_config_from_cli( - args, benchmark_mode_str - ) + effective_config = _build_config_from_cli(args, benchmark_mode_str) test_mode = ( TestMode(args.mode) if getattr(args, "mode", None) else TestMode.PERF ) @@ -294,7 +297,7 @@ def _build_config_from_cli( version="1.0", type=TestType.OFFLINE if benchmark_mode == "offline" else 
TestType.ONLINE, datasets=[ - Dataset( + DatasetConfig( name=args.dataset.stem, type=DatasetType.PERFORMANCE, path=str(args.dataset), @@ -343,7 +346,6 @@ def _build_config_from_cli( api_type=api_type, ), metrics=Metrics(), - baseline=None, # CLI mode doesn't use baseline report_dir=report_dir, timeout=timeout, verbose=verbose_level > 0, @@ -464,6 +466,18 @@ def _run_benchmark( if len(accuracy_configs) > 0: # Pack the evaluation parameters for each accuracy dataset for acc_config in accuracy_configs: + # Type narrowing: ensure accuracy_config is not None + assert ( + acc_config.accuracy_config is not None + ), f"accuracy_config must be set for dataset {acc_config.name}" + # Type narrowing: ensure required fields are not None + assert ( + acc_config.accuracy_config.eval_method is not None + ), f"eval_method must be set for dataset {acc_config.name}" + assert ( + acc_config.accuracy_config.extractor is not None + ), f"extractor must be set for dataset {acc_config.name}" + dataset = DataLoaderFactory.create_loader( acc_config, num_repeats=acc_config.accuracy_config.num_repeats ) @@ -475,7 +489,7 @@ def _run_benchmark( Extractor.get(acc_config.accuracy_config.extractor), acc_config.name, dataset, - config.report_dir, + report_dir, acc_config.accuracy_config.ground_truth, acc_config.accuracy_config.num_repeats, ) @@ -554,18 +568,18 @@ def _run_benchmark( # Create endpoint client endpoints = config.endpoint_config.endpoints + assert endpoints is not None num_workers = config.settings.client.workers logger.info(f"Connecting: {endpoints}") tmp_dir = tempfile.mkdtemp(prefix="inference_endpoint_") try: + api_type: APIType = config.endpoint_config.api_type + assert api_type is not None http_config = HTTPClientConfig( - endpoint_urls=[ - urljoin(e, config.endpoint_config.api_type.default_route()) - for e in endpoints - ], - api_type=config.endpoint_config.api_type, + endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints], + api_type=api_type, 
num_workers=num_workers, record_worker_events=config.settings.client.record_worker_events, event_logs_dir=report_dir, @@ -595,7 +609,9 @@ def _run_benchmark( report_dir=report_dir, tokenizer_override=tokenizer, accuracy_datasets=accuracy_datasets, - max_shutdown_timeout_s=config.timeout if config.timeout else None, + max_shutdown_timeout_s=config.timeout + if config.timeout + else SystemDefaults.DEFAULT_TIMEOUT, dump_events_log=True, ) @@ -622,6 +638,7 @@ def signal_handler(signum, frame): ground_truth_column=eval_config.ground_truth_column, ) score, n_repeats = scorer_instance.score() + assert eval_config.dataset.data is not None accuracy_scores[eval_config.dataset_name] = { "dataset_name": eval_config.dataset_name, "num_samples": len(eval_config.dataset.data), @@ -665,7 +682,7 @@ def signal_handler(signum, frame): logger.warning(f" ... +{len(response_collector.errors) - 3} more") try: - results = { + results: dict[str, Any] = { "config": { "endpoint": endpoints, "mode": test_mode, diff --git a/src/inference_endpoint/commands/probe.py b/src/inference_endpoint/commands/probe.py index 8ede0f39..cb704e03 100644 --- a/src/inference_endpoint/commands/probe.py +++ b/src/inference_endpoint/commands/probe.py @@ -136,6 +136,7 @@ async def run_probe_command(args: argparse.Namespace) -> None: ): try: # Schedule receive on client's event loop and await the result + assert client.loop is not None, "Client loop should be initialized" future = asyncio.run_coroutine_threadsafe(client.recv(), client.loop) result = await asyncio.wrap_future(future) @@ -163,11 +164,18 @@ async def run_probe_command(args: argparse.Namespace) -> None: continue latency_ms = (time.time() - start_times[query_id]) * 1000 + # Normalize response_output for logging + response_output = result.get_response_output_string() + if response_output is None: + response_output = "" + if result.error: errors.append(f"{query_id}: {result.error}") else: latencies.append(latency_ms) - responses.append((query_id, 
result.response_output)) + responses.append( + (query_id, response_output if response_output else "") + ) # Simple progress indicator if ( @@ -175,9 +183,7 @@ async def run_probe_command(args: argparse.Namespace) -> None: or len(received_ids) == num_expected ): output_preview = ( - result.response_output[:100] - if result.response_output - else "(no output)" + response_output[:100] if response_output else "(no output)" ) logger.info( f" Processed {len(received_ids)}/{num_expected} responses : {query_id} : {output_preview}" diff --git a/src/inference_endpoint/config/rulesets/mlcommons/datasets.py b/src/inference_endpoint/config/rulesets/mlcommons/datasets.py index f86771a3..84569d67 100644 --- a/src/inference_endpoint/config/rulesets/mlcommons/datasets.py +++ b/src/inference_endpoint/config/rulesets/mlcommons/datasets.py @@ -64,4 +64,4 @@ def _disallow_instantiation(cls, *args, **kwargs): ) -_Dataset.__new__ = _disallow_instantiation +_Dataset.__new__ = _disallow_instantiation # type: ignore[method-assign] diff --git a/src/inference_endpoint/config/rulesets/mlcommons/models.py b/src/inference_endpoint/config/rulesets/mlcommons/models.py index 7cda69c3..e03b72e6 100644 --- a/src/inference_endpoint/config/rulesets/mlcommons/models.py +++ b/src/inference_endpoint/config/rulesets/mlcommons/models.py @@ -154,4 +154,4 @@ def _disallow_instantiation(cls, *args, **kwargs): ) -_Model.__new__ = _disallow_instantiation +_Model.__new__ = _disallow_instantiation # type: ignore[method-assign] diff --git a/src/inference_endpoint/config/rulesets/mlcommons/rules.py b/src/inference_endpoint/config/rulesets/mlcommons/rules.py index 121e2049..ae695c39 100644 --- a/src/inference_endpoint/config/rulesets/mlcommons/rules.py +++ b/src/inference_endpoint/config/rulesets/mlcommons/rules.py @@ -26,6 +26,7 @@ from .... 
import metrics from ...ruleset_base import BenchmarkSuiteRuleset from ...runtime_settings import RuntimeSettings +from ...schema import SystemDefaults from ...user_config import UserConfig from . import models @@ -38,8 +39,12 @@ class PerModelRuleset: min_duration_ms_valid: int = ( 10 * 60 * 1000 ) # Minimum duration in milliseconds required for a valid run - max_duration_ms_valid: int = None # Maximum duration in milliseconds. Used as a timeout / kill for a benchmark run. Set to None for no timeout. - min_sample_count_valid: int = None # Minimum number of samples required to be sent to the SUT for a valid run, if None, no minimum is enforced + max_duration_ms_valid: int | None = ( + None # Maximum duration in milliseconds. Used as a timeout / kill for a benchmark run. Set to None for no timeout. + ) + min_sample_count_valid: int | None = ( + None # Minimum number of samples required to be sent to the SUT for a valid run, if None, no minimum is enforced + ) metric: type[metrics.Metric] | None = ( None # any subclass of Metric. Used as metric to evaluate the performance of the benchmark. ) @@ -55,7 +60,9 @@ class PerQueryRuleset(PerModelRuleset): target_latency_percentile: float = ( 99.0 # Percentile of per-query latencies to use for metric comparison ) - max_latency_threshold_ms: int = None # Maximum latency threshold in milliseconds for the specified percentile latency allowed for a valid run. + max_latency_threshold_ms: int | None = ( + None # Maximum latency threshold in milliseconds for the specified percentile latency allowed for a valid run. 
+ ) reported_metrics: list[type[metrics.Metric]] = field( default_factory=lambda: [metrics.Throughput] ) @@ -65,10 +72,10 @@ class PerQueryRuleset(PerModelRuleset): class TokenBasedRuleset(PerModelRuleset): min_sample_count_valid: int = 270336 metric: type[metrics.Metric] = metrics.Throughput - max_ttft_latency_ms: int = ( + max_ttft_latency_ms: int | None = ( None # Maximum TTFT latency in milliseconds allowed for a valid run ) - max_tpot_latency_ms: int = ( + max_tpot_latency_ms: int | None = ( None # Maximum TPoT latency in milliseconds allowed for a valid run ) reported_metrics: list[type[metrics.Metric]] = field( @@ -148,17 +155,19 @@ def apply_user_config( ruleset = self.benchmark_rulesets[model][opt_prio] - metric_target = None - if user_config.user_metric_target: + metric_target: metrics.Metric | None = None + if user_config.user_metric_target and ruleset.metric: metric_target = ruleset.metric(user_config.user_metric_target) - reported_metrics = [] + reported_metrics: list[metrics.Metric] = [] for mtype in ruleset.reported_metrics: - if mtype == ruleset.metric: + if mtype == ruleset.metric and metric_target: reported_metrics.append(metric_target) elif mtype == metrics.TTFT: + assert isinstance(ruleset, TokenBasedRuleset) reported_metrics.append(metrics.TTFT(ruleset.max_ttft_latency_ms)) elif mtype == metrics.TPOT: + assert isinstance(ruleset, TokenBasedRuleset) reported_metrics.append(metrics.TPOT(ruleset.max_tpot_latency_ms)) elif mtype == metrics.QueryLatency and ruleset.metric == metrics.Throughput: # If we specify throughput and want to also report per query latency, infer latency from inverting qps. 
@@ -166,6 +175,7 @@ def apply_user_config( metrics.QueryLatency(target_qps=user_config.user_metric_target) ) elif mtype == metrics.Throughput and ruleset.metric == metrics.QueryLatency: + assert user_config.user_metric_target is not None # If we specify per query latency, infer qps by inverting target_qps = 1000 / user_config.user_metric_target reported_metrics.append(metrics.Throughput(target_qps=target_qps)) @@ -182,7 +192,7 @@ def apply_user_config( if user_config.max_duration_ms is not None: max_duration_ms = user_config.max_duration_ms assert ( - max_duration_ms >= min_duration_ms + max_duration_ms is not None and max_duration_ms >= min_duration_ms ), "Max duration must be greater than or equal to min duration" n_samples_from_dataset = model.dataset.size @@ -198,13 +208,15 @@ def apply_user_config( min_sample_count = user_config.min_sample_count return _RuntimeSettings( - metric_target=metric_target, + metric_target=metric_target + if metric_target is not None + else SystemDefaults.DEFAULT_METRIC, reported_metrics=reported_metrics, min_duration_ms=min_duration_ms, max_duration_ms=max_duration_ms, n_samples_from_dataset=n_samples_from_dataset, n_samples_to_issue=total_sample_count, - min_sample_count=min_sample_count, + min_sample_count=min_sample_count if min_sample_count is not None else 1, rng_sched=random.Random(self.scheduler_rng_seed), rng_sample_index=random.Random(self.sample_index_rng_seed), load_pattern=None, # not part user config diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index b51d9aa9..fa7f06bb 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -164,7 +164,7 @@ def _from_config_default( # Apply overrides kwargs.update(overrides) - return cls(**kwargs) + return cls(**kwargs) # type: ignore[arg-type] def total_samples_to_issue( self, padding_factor: float = 1.1, align_to_dataset_size: bool = True diff --git 
a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 4157d275..9a1ccb8d 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -22,6 +22,7 @@ from enum import Enum from pathlib import Path +from typing import ClassVar import yaml from pydantic import BaseModel, Field @@ -30,6 +31,11 @@ from .ruleset_base import BenchmarkSuiteRuleset +class SystemDefaults(BaseModel): + DEFAULT_TIMEOUT: ClassVar[float] = 300.0 + DEFAULT_METRIC: ClassVar[metrics.Metric] = metrics.Throughput(0.0) + + class APIType(str, Enum): OPENAI = "openai" SGLANG = "sglang" @@ -198,16 +204,22 @@ class Dataset(BaseModel): format: str | None = None samples: int | None = None eval_method: EvalMethod | None = None - parser: dict | None = None + parser: dict[str, str] | None = None accuracy_config: AccuracyConfig | None = None class AccuracyConfig(BaseModel): """Accuracy configuration. - The eval_method is the method to use to evaluate the accuracy of the model. Currently only "pass_at_1" is supported. - The ground_truth is the column in the dataset that contains the ground truth. Defaults to "ground_truth" if not specified. - The extractor is the extractor to use to extract the ground truth from the output. Currently "boxed_math_extractor" and "abcd_extractor" are supported. - The num_repeats is the number of times to repeat the dataset for evaluation. Defaults to 1 if not specified. + + The eval_method is the method to use to evaluate the accuracy of the model. + Currently only "pass_at_1" is supported. + The ground_truth is the column in the dataset that contains the ground truth. + Defaults to "ground_truth" if not specified. + The extractor is the extractor to use to extract the ground truth from the output. + Currently "boxed_math_extractor" and "abcd_extractor" are supported. + The num_repeats is the number of times to repeat the dataset for evaluation. + Defaults to 1 if not specified. 
+ Example: accuracy_config: eval_method: "pass_at_1" @@ -373,7 +385,7 @@ class BenchmarkConfig(BaseModel): # workers are assigned endpoints in a round-robin manner endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig) report_dir: Path | None = None - timeout: int | None = None + timeout: float | None = None verbose: bool = False # CPU affinity for loadgen and worker processes: # - True = auto (compute optimal NUMA-aware plan) diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index fff8c07c..c02553da 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -52,7 +52,7 @@ class QueryStatus(Enum): _OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None -class Query(msgspec.Struct, kw_only=True): +class Query(msgspec.Struct, kw_only=True): # type: ignore[call-arg] """Represents a single inference query to be sent to an endpoint. A Query encapsulates all information needed to make an HTTP request to @@ -80,7 +80,7 @@ class Query(msgspec.Struct, kw_only=True): created_at: float = msgspec.field(default_factory=time.time) -class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True): +class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True): # type: ignore[call-arg] """Result of a completed inference query. Represents the outcome of processing a Query, including the response text, @@ -112,7 +112,7 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True) response_output: _OUTPUT_RESULT_TYPE = None metadata: dict[str, Any] = msgspec.field(default_factory=dict) error: str | None = None - completed_at: int = msgspec.UNSET + completed_at: int | msgspec.UnsetType = msgspec.UNSET def __post_init__(self): """Set completion timestamp automatically. 
@@ -142,8 +142,19 @@ def __post_init__(self): if isinstance(v, list): self.response_output[k] = tuple(v) + def get_response_output_string(self) -> str: + """Get the response output as a string.""" + if isinstance(self.response_output, tuple): + return "".join(self.response_output) + elif isinstance(self.response_output, dict): + return str(self.response_output) + elif isinstance(self.response_output, str): + return self.response_output + else: + return "" + -class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True): +class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True): # type: ignore[call-arg] """A single chunk from a streaming inference response. Streaming responses are sent incrementally as the model generates text. diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index b88a6afa..db09395a 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -104,7 +104,7 @@ def __init__( self, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.dataframe: pd.DataFrame | None = None @@ -118,12 +118,13 @@ def get_dataframe(self) -> pd.DataFrame: def get_num_samples(self) -> int: """Get the number of samples in the dataset.""" + assert self.dataframe is not None return len(self.dataframe) @classmethod def get_loader( cls, file_path: os.PathLike, format: DatasetFormat | None = None - ) -> "DatafileLoader": + ) -> type["DatafileLoader"]: """Get the loader for the dataset.""" if format is not None: @@ -142,7 +143,7 @@ def __init__( file_path: Path | str, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.parquet_path = Path(file_path) @@ -157,7 +158,7 @@ def __init__( file_path: Path | str | None = None, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.file_path = file_path self.dataset_name = kwargs.get("dataset_name", None) @@ -182,7 +183,7 @@ def __init__( 
csv_path: Path | str, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.csv_path = Path(csv_path) @@ -196,7 +197,7 @@ def __init__( file_path: Path | str, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.pickle_path = Path(file_path) @@ -210,7 +211,7 @@ def __init__( jsonl_path: Path | str, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.jsonl_path = Path(jsonl_path) @@ -224,7 +225,7 @@ def __init__( json_path: Path | str, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.json_path = Path(json_path) @@ -294,6 +295,9 @@ class Dataset: PREDEFINED: ClassVar[dict[str, type["Dataset"]]] = {} """A dictionary of predefined datasets, as subclasses of Dataset.""" + DATASET_ID: ClassVar[str] + """The unique identifier for the dataset. Automatically set by __init_subclass__.""" + def __init_subclass__( cls, dataset_id: str | None = None, @@ -312,8 +316,12 @@ def __init__( dataframe: pd.DataFrame | None = None, transforms: list[Transform] | None = None, repeats: int = 1, - ): + ) -> None: if self.__class__.COLUMN_NAMES is not None: + if dataframe is None: + raise ValueError( + f"dataframe cannot be None when COLUMN_NAMES is specified for {self.__class__.__name__}" + ) common = set(self.__class__.COLUMN_NAMES) & set(dataframe.columns) if len(common) != len(self.__class__.COLUMN_NAMES): missing = set(self.__class__.COLUMN_NAMES) - common @@ -325,6 +333,7 @@ def __init__( self.logger = getLogger(__name__) self.transforms = transforms self.repeats = repeats + self.data: list[dict[str, Any]] | None = None @classmethod def load_from_file( @@ -374,6 +383,10 @@ def load( return df = self.dataframe + if df is None: + raise ValueError( + f"Cannot load dataset {self.__class__.__name__}: dataframe is None" + ) transforms = [] if self.transforms is not None: @@ -381,9 +394,9 @@ def load( # If adapter is specified, use it to get transforms, otherwise fallback to use APIType to # get 
transforms. - if adapter is not None: + if adapter is not None and model_params is not None: transforms.extend(adapter.dataset_transforms(model_params)) - elif api_type is not None: + elif api_type is not None and model_params is not None: transforms.extend(get_transforms_for_api_type(api_type, model_params)) if transforms: @@ -413,9 +426,11 @@ def load_sample(self, index: int) -> Any: IndexError: If index is out of range. IOError: If data cannot be loaded from disk. """ + assert self.data is not None, "Dataset not loaded. Call load() first." return self.data[index] def num_samples(self) -> int: + assert self.data is not None, "Dataset not loaded. Call load() first." return len(self.data) @classmethod @@ -444,11 +459,11 @@ def get_dataloader( class EmptyDataset(Dataset): """Empty dataset to be used as performance dataset when running only accuracy tests.""" - def __init__(self): + def __init__(self) -> None: super().__init__(None) - def load_sample(self, index: int): + def load_sample(self, index: int) -> None: return None - def num_samples(self): + def num_samples(self) -> int: return 0 diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 2157480d..0998f3f9 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -23,7 +23,7 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat -from .transforms import ColumnRemap, MakeAdapterCompatible +from .transforms import ColumnRemap, MakeAdapterCompatible, Transform logger = logging.getLogger(__name__) @@ -89,17 +89,21 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data f"Dataset {name} is not predefined and no dataset path provided - predefined datasets are: {list(Dataset.PREDEFINED.keys())}" ) + format_enum: DatasetFormat | None = None if file_format is not None: - 
file_format = DatasetFormat(file_format) + format_enum = DatasetFormat(file_format) - transforms = [] + transforms: list[Transform] = [] if remap is not None: - transforms.append(ColumnRemap(remap)) + transforms.append(ColumnRemap(remap)) # type: ignore[arg-type] transforms.append(MakeAdapterCompatible()) + assert dataset_path is not None + from pathlib import Path + return Dataset.load_from_file( - dataset_path, + Path(dataset_path), transforms=transforms, - format=file_format, + format=format_enum, num_repeats=num_repeats, ) diff --git a/src/inference_endpoint/dataset_manager/predefined/random/__init__.py b/src/inference_endpoint/dataset_manager/predefined/random/__init__.py index 0959bbab..45e096b0 100644 --- a/src/inference_endpoint/dataset_manager/predefined/random/__init__.py +++ b/src/inference_endpoint/dataset_manager/predefined/random/__init__.py @@ -48,20 +48,21 @@ def generate( rng = np.random.default_rng(random_seed) data = [] # Generate the input sequence lengths given the range ratio - input_seq_length = rng.integers( + input_seq_lengths = rng.integers( int(input_seq_length * range_ratio), input_seq_length + 1, num_sequences, ) # Generate the input starts randomly from the vocab size - input_starts = rng.integers(0, tokenizer.vocab_size, num_sequences) + input_starts_array = rng.integers(0, tokenizer.vocab_size, num_sequences) # Generate the input sequences for i in range(num_sequences): # Generate the input sequence by adding the input starts to the input sequence lengths and modding by the vocab size + seq_len = int(input_seq_lengths[i]) + start_val = int(input_starts_array[i]) input_sequence = [ - (input_starts[i] + j) % tokenizer.vocab_size - for j in range(input_seq_length[i]) + (start_val + j) % tokenizer.vocab_size for j in range(seq_len) ] # Decode the input sequence to get the text prompt prompt = tokenizer.decode(input_sequence, add_special_tokens=False) @@ -76,7 +77,7 @@ def generate( { "prompt": prompt, "input_tokens": input_tokens, - 
"input_seq_length": input_seq_length[i], + "input_seq_length": seq_len, } ) diff --git a/src/inference_endpoint/endpoint_client/cpu_affinity.py b/src/inference_endpoint/endpoint_client/cpu_affinity.py index a280e085..303229c2 100644 --- a/src/inference_endpoint/endpoint_client/cpu_affinity.py +++ b/src/inference_endpoint/endpoint_client/cpu_affinity.py @@ -231,16 +231,17 @@ def core_perf_rank(numa: int, phys: int) -> int: worker_phys_cores.append((primary_numa, phys)) # Sort other NUMA nodes by their best core's performance - other_numas = sorted( + other_numas: list[int] = sorted( (numa for numa in numa_cores if numa != primary_numa), key=lambda n: min(core_perf_rank(n, p) for p in numa_cores[n]), ) # Add cores from other NUMAs (each NUMA's cores sorted by perf) for numa in other_numas: + numa_id = numa # Capture numa in local variable for lambda closure sorted_cores = sorted( numa_cores[numa].keys(), - key=lambda p: core_perf_rank(numa, p), + key=lambda p: core_perf_rank(numa_id, p), ) for phys in sorted_cores: worker_phys_cores.append((numa, phys)) @@ -253,7 +254,7 @@ def core_perf_rank(numa: int, phys: int) -> int: # Log NUMA distribution for workers if worker_phys_cores: - numa_distribution = {} + numa_distribution: dict[int, int] = {} for numa, _ in worker_phys_cores: numa_distribution[numa] = numa_distribution.get(numa, 0) + 1 logger.debug(f"Worker cores by NUMA: {numa_distribution}") @@ -363,7 +364,7 @@ def _parse_cpulist(cpulist_str: str) -> set[int]: """Parse a CPU list string (e.g., '0-3,8-11') into a set.""" if not cpulist_str: return set() - cpus = set() + cpus: set[int] = set() for part in cpulist_str.split(","): if "-" in (p := part.strip()): start, end = p.split("-", 1) diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py index 0bcc9647..1c37454a 100644 --- a/src/inference_endpoint/endpoint_client/http.py +++ b/src/inference_endpoint/endpoint_client/http.py @@ -178,7 +178,12 @@ def 
_signal_stream_end(self) -> None: # asyncio.Protocol callbacks # ------------------------------------------------------------------------- - def connection_made(self, transport: asyncio.Transport) -> None: + def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore[override] + """Called by asyncio when connection is established. + + Note: We intentionally narrow the transport type from BaseTransport to Transport + for better type safety, as we know we're using TCP transports with specific features. + """ self._transport = transport self._parser = httptools.HttpResponseParser(self) @@ -252,6 +257,8 @@ def on_header(self, name: bytes, value: bytes) -> None: ) def on_headers_complete(self) -> None: + # Parser is always set when this callback is invoked by httptools + assert self._parser is not None self._status_code = self._parser.get_status_code() # Check if server wants to close connection (Connection: close or HTTP/1.0) self._should_close = not self._parser.should_keep_alive() diff --git a/src/inference_endpoint/endpoint_client/http_client.py b/src/inference_endpoint/endpoint_client/http_client.py index 12bf8fd0..57f2e7af 100644 --- a/src/inference_endpoint/endpoint_client/http_client.py +++ b/src/inference_endpoint/endpoint_client/http_client.py @@ -57,17 +57,18 @@ def __init__( self._worker_cycle = cycle(range(self.config.num_workers)) # Use provided loop or create own + self._owns_loop = loop is None + self.loop: asyncio.AbstractEventLoop | None = loop + self._loop_thread: threading.Thread | None = None if loop is None: self.loop = uvloop.new_event_loop() + assert self.loop is not None self._loop_thread = threading.Thread( target=self.loop.run_forever, daemon=True, name=f"HttpClient-{self.client_id}", ) self._loop_thread.start() - else: - self.loop = loop - self._loop_thread = None # Use eager task factory for immediate coroutine execution # Tasks start executing synchronously until first await @@ -75,11 +76,15 @@ def __init__( # NOTE(vir): 
# CRITICAL for http-client performance # ensures issue() does not get starved by other threads under load - self.loop.set_task_factory(asyncio.eager_task_factory) + assert self.loop is not None + self.loop.set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type] # Initialize on event loop asyncio.run_coroutine_threadsafe(self._initialize(), self.loop).result() + assert self.config.adapter is not None + assert self.config.accumulator is not None + assert self.config.worker_pool_transport is not None logger.info( f"EndpointClient initialized with num_workers={self.config.num_workers}, " f"endpoints={self.config.endpoint_urls}, " @@ -95,6 +100,7 @@ async def _initialize(self) -> None: self._dropped_requests: int = 0 # WorkerManager creates and owns all transports + assert self.loop is not None self.worker_manager = WorkerManager(self.config, self.loop) await self.worker_manager.initialize() self.pool = self.worker_manager.pool_transport @@ -133,8 +139,8 @@ async def shutdown(self) -> None: # Shutdown workers await self.worker_manager.shutdown() - # Stop event loop if we own it - if self._loop_thread is not None: + # Stop event loop if we own it (scheduled to run after this coroutine completes) + if self._owns_loop and self.loop is not None: self.loop.call_soon(self.loop.stop) if self._dropped_requests > 0: @@ -157,10 +163,12 @@ class HTTPEndpointClient(AsyncHttpEndpointClient): def issue(self, query: Query) -> None: # type: ignore[override] """Issue query.""" # Schedule on event loop thread + assert self.loop is not None self.loop.call_soon_threadsafe( lambda: super(HTTPEndpointClient, self).issue(query) ) def shutdown(self) -> None: # type: ignore[override] - """Sync shutdown.""" + """Sync shutdown wrapper - blocks until base class async shutdown completes.""" + assert self.loop is not None asyncio.run_coroutine_threadsafe(super().shutdown(), self.loop).result() diff --git a/src/inference_endpoint/endpoint_client/http_sample_issuer.py 
b/src/inference_endpoint/endpoint_client/http_sample_issuer.py index c30379ee..0f499afa 100644 --- a/src/inference_endpoint/endpoint_client/http_sample_issuer.py +++ b/src/inference_endpoint/endpoint_client/http_sample_issuer.py @@ -51,6 +51,7 @@ def __init__( self.http_client = http_client # Start response handler task to route completed responses back to SampleEventHandler + assert self.http_client.loop is not None self._response_task = asyncio.run_coroutine_threadsafe( self._handle_responses(), self.http_client.loop ) diff --git a/src/inference_endpoint/endpoint_client/transport/zmq/transport.py b/src/inference_endpoint/endpoint_client/transport/zmq/transport.py index 979d938b..4cb4c64f 100644 --- a/src/inference_endpoint/endpoint_client/transport/zmq/transport.py +++ b/src/inference_endpoint/endpoint_client/transport/zmq/transport.py @@ -398,7 +398,19 @@ def _create_receiver( message_type: type | None = None, bind: bool = False, ) -> _ZmqReceiverTransport: - """Create a ZMQ receiver transport.""" + """Create a ZMQ receiver transport. + + Args: + loop: Event loop for transport registration. + address: ZMQ address (e.g., "ipc:///tmp/socket"). + context: ZMQ context. + config: Socket configuration. + message_type: Type hint for msgspec decoder. Can be a single type, Union type, or None. + bind: Whether to bind (True) or connect (False). + + Returns: + Configured receiver transport. 
+ """ sock = context.socket(zmq.PULL) sock.setsockopt(zmq.LINGER, config.linger) sock.setsockopt(zmq.RCVHWM, config.high_water_mark) @@ -410,7 +422,7 @@ def _create_receiver( sock.connect(address) decoder = ( - msgspec.msgpack.Decoder(type=message_type) + msgspec.msgpack.Decoder(type=message_type) # type: ignore[arg-type] if message_type else msgspec.msgpack.Decoder() ) @@ -604,7 +616,7 @@ def __init__( self._response_addr, self._context, config, - QueryResult | StreamChunk, + QueryResult | StreamChunk, # type: ignore[arg-type] bind=True, ) self._readiness_receiver = _create_receiver( diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 391e3b0e..7baa4568 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -29,6 +29,7 @@ from urllib.parse import urlparse from inference_endpoint.core.types import Query, QueryResult +from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.http import ( ConnectionPool, @@ -164,7 +165,8 @@ def __init__( self._active_tasks: set[asyncio.Task] = set() # Use adapter type from config - self._adapter = self.http_config.adapter + assert self.http_config.adapter is not None + self._adapter: type[HttpRequestAdapter] = self.http_config.adapter async def run(self) -> None: """Main worker loop - pull requests, execute, push responses.""" @@ -175,7 +177,8 @@ async def run(self) -> None: # Use eager task factory for immediate coroutine execution # Tasks start executing synchronously until first await # NOTE(vir): CRITICAL for minimizing TFB/TTFT - self._loop.set_task_factory(asyncio.eager_task_factory) + assert self._loop is not None + self._loop.set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type] # Initialize HTTP template from URL components self._http_template = 
HttpRequestTemplate.from_url( @@ -249,7 +252,9 @@ async def run(self) -> None: # Run main processing loop if self.http_config.record_worker_events: - worker_db_name = f"worker_report_{self.worker_id}_{os.getpid()}" + pid = os.getpid() + worker_db_name = f"worker_report_{self.worker_id}_{pid}" + assert self.http_config.event_logs_dir is not None report_path = self.http_config.event_logs_dir / f"{worker_db_name}.csv" with EventRecorder(session_id=worker_db_name) as event_recorder: @@ -312,6 +317,7 @@ async def _run_main_loop(self) -> None: continue # Process response asynchronously + assert self._loop is not None task = self._loop.create_task(self._process_response(prepared)) # Keep task alive to prevent GC @@ -333,6 +339,7 @@ def _prepare_request(self, query: Query) -> InFlightRequest: is_streaming = query.data.get("stream", False) # Build complete HTTP request bytes + assert self._http_template is not None http_bytes = self._http_template.build_request( body_bytes, is_streaming, @@ -364,6 +371,7 @@ async def _fire_request(self, req: InFlightRequest) -> bool: try: # Acquire connection from pool + assert self._pool is not None conn = await self._pool.acquire() # Write request bytes directly to transport @@ -384,13 +392,15 @@ async def _process_response(self, req: InFlightRequest) -> None: """Process response for a fired request.""" try: conn = req.connection + assert conn is not None, "Connection should be set by _fire_request" # Await headers and handle error status status_code, _ = await conn.protocol.read_headers() if status_code != 200: error_body = await conn.protocol.read_body() # Release connection early - done with socket I/O - self._pool.release(req.connection) + assert self._pool is not None + self._pool.release(conn) req.connection = None await self._handle_error( req.query_id, @@ -411,6 +421,7 @@ async def _process_response(self, req: InFlightRequest) -> None: finally: # Release connection back to pool if not already released if req.connection: + assert 
self._pool is not None self._pool.release(req.connection) req.connection = None @@ -424,20 +435,25 @@ async def _process_response(self, req: InFlightRequest) -> None: ) # Clean up task reference - self._active_tasks.discard(asyncio.current_task()) + current_task = asyncio.current_task() + if current_task is not None: + self._active_tasks.discard(current_task) @profile async def _handle_streaming_body(self, req: InFlightRequest) -> None: """Handle streaming (SSE) response body.""" conn = req.connection + assert conn is not None query_id = req.query_id # Create accumulator for streaming response + assert self.http_config.accumulator is not None accumulator = self.http_config.accumulator( query_id, self.http_config.stream_all_chunks ) # Process SSE stream - yields batches of chunks + assert self._responses is not None async for chunk_batch in self._iter_sse_lines(conn): for delta in chunk_batch: if stream_chunk := accumulator.add_chunk(delta): @@ -452,6 +468,7 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: ) # Release connection early - done with socket I/O + assert self._pool is not None self._pool.release(conn) req.connection = None @@ -469,12 +486,14 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: """Handle non-streaming response body.""" conn = req.connection + assert conn is not None query_id = req.query_id # Read entire response body response_bytes = await conn.protocol.read_body() # Release connection early - done with socket I/O + assert self._pool is not None self._pool.release(conn) req.connection = None @@ -482,6 +501,7 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: result = self._adapter.decode_response(response_bytes, query_id) # Send result back to main rank + assert self._responses is not None self._responses.send(result) if self.http_config.record_worker_events: EventRecorder.record_event( diff --git 
a/src/inference_endpoint/endpoint_client/worker_manager.py b/src/inference_endpoint/endpoint_client/worker_manager.py index d26d3cb7..a6e3b29f 100644 --- a/src/inference_endpoint/endpoint_client/worker_manager.py +++ b/src/inference_endpoint/endpoint_client/worker_manager.py @@ -50,6 +50,8 @@ def __init__( self.http_config = http_config # Create pool transport via factory + # worker_pool_transport is guaranteed to be set by HTTPClientConfig.__post_init__ + assert http_config.worker_pool_transport is not None self.pool_transport: WorkerPoolTransport = ( http_config.worker_pool_transport.create( loop, num_workers=http_config.num_workers @@ -71,6 +73,9 @@ async def initialize(self) -> None: for i in range(self.http_config.num_workers): process = self._spawn_worker(i, connector) self.workers.append(process) + assert ( + process.pid is not None + ), "Worker process should have a PID after spawning" self.worker_pids[i] = process.pid # Apply CPU affinity after all workers are started diff --git a/src/inference_endpoint/evaluation/extractor.py b/src/inference_endpoint/evaluation/extractor.py index 4bb17c5b..99d07db4 100644 --- a/src/inference_endpoint/evaluation/extractor.py +++ b/src/inference_endpoint/evaluation/extractor.py @@ -31,6 +31,9 @@ class Extractor(ABC): # This allows registering new extractors that can be instantiated via config/lookup. PREDEFINED: ClassVar[dict[str, type["Extractor"]]] = {} + EXTRACTOR_ID: ClassVar[str] + """The unique identifier for the extractor. 
Automatically set by __init_subclass__.""" + def __init_subclass__( cls, extractor_id: str | None = None, diff --git a/src/inference_endpoint/evaluation/livecodebench/lcb_serve.py b/src/inference_endpoint/evaluation/livecodebench/lcb_serve.py index c272ca9a..71fc6bd2 100644 --- a/src/inference_endpoint/evaluation/livecodebench/lcb_serve.py +++ b/src/inference_endpoint/evaluation/livecodebench/lcb_serve.py @@ -77,7 +77,9 @@ def execute_code_single(test_suite_json: str, code: str, timeout_sec: int = 60): return fixed_types, metadata -def execute_code_single_suppressed_errors(*args, resp_buffer: list = None, **kwargs): +def execute_code_single_suppressed_errors( + *args, resp_buffer: list | None = None, **kwargs +): """Wrapper around execute code so that all errors are resurfaced as failed tests""" try: res, metadata = execute_code_single(*args, **kwargs) @@ -264,9 +266,9 @@ def __call__( dict[str, list[bool]]: Dictionary mapping question IDs to lists of boolean results for each code sample. """ # Create results dict with the expected size - results = {} + results: dict[str, list[bool]] = {} for qid, test_codes in zip(question_ids, codes, strict=False): - results[qid] = [None] * len(test_codes) + results[qid] = [False] * len(test_codes) futures = {} with ProcessPoolExecutor(max_workers=self.n_lcb_workers) as executor: @@ -380,7 +382,7 @@ def __init__( qid ) # Accessing will populate the cache - def cache_info(self) -> dict[str, int]: + def cache_info(self): """Returns the cache information for the test loader.""" return self.test_loader.cache_info() @@ -406,6 +408,7 @@ def evaluate( KeyError: If any question_id is not found in the loaded test suites. 
""" # Validate all question IDs exist in test suites + assert self.df is not None, "Dataset not loaded" invalid_ids = [ qid for qid in codes_dict.keys() if qid not in self.df["question_id"].values ] diff --git a/src/inference_endpoint/evaluation/livecodebench/run_lcb_tests.py b/src/inference_endpoint/evaluation/livecodebench/run_lcb_tests.py index 9fbe322b..16798f6f 100644 --- a/src/inference_endpoint/evaluation/livecodebench/run_lcb_tests.py +++ b/src/inference_endpoint/evaluation/livecodebench/run_lcb_tests.py @@ -296,7 +296,7 @@ def grade_call_based( all_outputs = [json.loads(output) for output in all_outputs] - total_execution = 0 + total_execution = 0.0 all_results = [] for gt_inp, gt_out in zip(all_inputs, all_outputs, strict=False): signal.alarm(timeout) @@ -377,7 +377,7 @@ def grade_stdio( return all_results = [] - total_execution_time = 0 + total_execution_time = 0.0 for gt_inp, gt_out in zip(all_inputs, all_outputs, strict=False): signal.alarm(timeout) faulthandler.enable() diff --git a/src/inference_endpoint/evaluation/scoring.py b/src/inference_endpoint/evaluation/scoring.py index ef1a32ed..1bf4e99e 100644 --- a/src/inference_endpoint/evaluation/scoring.py +++ b/src/inference_endpoint/evaluation/scoring.py @@ -47,6 +47,7 @@ class Scorer(ABC): """ PREDEFINED: ClassVar[dict[str, type["Scorer"]]] = {} + SCORER_ID: ClassVar[str] def __init_subclass__( cls, @@ -92,7 +93,7 @@ def __init__( dataset: Dataset, report_dir: os.PathLike, extractor: type[Extractor] | None = None, - ground_truth_column: str = "ground_truth", + ground_truth_column: str | None = "ground_truth", ): self.dataset = dataset self.report_dir = Path(report_dir) @@ -146,11 +147,12 @@ def match_sample_index(self, row: pd.Series) -> pd.Series: def score_single_sample(self, value: str, ground_truth: str) -> float: raise NotImplementedError - def score(self) -> tuple[float, int]: + def score(self) -> tuple[float | None, int]: """Scores the dataset and returns the mean score and the number of 
repeats. Returns: - tuple[float, int]: The mean score and the number of repeats. + tuple[float | None, int]: The mean score and the number of repeats. + Returns None as the score if evaluation fails. """ df = self.get_outputs() @@ -168,6 +170,9 @@ def score(self) -> tuple[float, int]: # Get ground truths order = df["sample_index"].to_numpy() + assert ( + self.dataset.dataframe is not None + ), f"Dataset {self.dataset} has no dataframe loaded" assert ( self.ground_truth_column in self.dataset.dataframe.columns ), f"Ground truth column {self.ground_truth_column} not found in dataset {self.dataset}" @@ -268,6 +273,9 @@ def score(self) -> tuple[float, int]: empirical = df["output"].tolist() order = df["sample_index"].to_numpy().astype(int) + assert ( + self.dataset.dataframe is not None + ), f"Dataset {self.dataset} has no dataframe loaded" assert ( self.ground_truth_column in self.dataset.dataframe.columns ), f"Ground truth column {self.ground_truth_column} not found in dataset {self.dataset}" @@ -535,6 +543,9 @@ def _evaluate_via_subprocess(self, df: pd.DataFrame) -> float | None: # Collect stdout while displaying it character-by-character to support # progress bars that use carriage returns + if process.stdout is None: + raise RuntimeError("Failed to capture subprocess stdout") + stdout_buffer = [] while True: char = process.stdout.read(1) @@ -589,13 +600,19 @@ def score(self) -> tuple[float | None, int]: df = df.apply(self.match_sample_index, axis=1) # Get question IDs + assert ( + self.dataset.dataframe is not None + ), f"Dataset {self.dataset} has no dataframe loaded" + def get_question_id(sample_index: int) -> str: + assert self.dataset.dataframe is not None return self.dataset.dataframe.iloc[sample_index][self.question_id_column] df["question_id"] = df["sample_index"].apply(get_question_id) # Extract code from outputs with default value for failed extractions # Use a comment that will fail all tests instead of None to maintain uniform list lengths + assert 
self.extractor is not None, "Extractor must be set for code extraction" df["extracted_code"] = df["output"].apply( lambda x: self.extractor.extract(x, default="# FAILED TO EXTRACT CODE") ) @@ -620,7 +637,7 @@ def get_question_id(sample_index: int) -> str: print( f"Server evaluated {total_samples} samples but returned an empty summary" ) - return None + return None, n_repeats total_passed = sum( sum(code_passed) for code_passed in per_problem_results.values() diff --git a/src/inference_endpoint/load_generator/load_generator.py b/src/inference_endpoint/load_generator/load_generator.py index 26525dc1..7047e483 100644 --- a/src/inference_endpoint/load_generator/load_generator.py +++ b/src/inference_endpoint/load_generator/load_generator.py @@ -135,10 +135,10 @@ def __init__( self.sample_issuer = sample_issuer self.dataloader = dataloader self.name = name - self.uuid_to_index_map = {} + self.uuid_to_index_map: dict[str, int] = {} @abstractmethod - def __next__(self) -> tuple[Sample, int]: + def __next__(self) -> IssuedSample: """Issue the next sample according to the load generation strategy. This method should: @@ -152,9 +152,7 @@ def __next__(self) -> tuple[Sample, int]: It should only return AFTER the sample has been issued. Returns: - Tuple of (sample, timestamp_ns): - - sample: The Sample object that was issued - - timestamp_ns: Monotonic nanosecond timestamp when issued + IssuedSample object containing the sample, index, and issue timestamp. Raises: StopIteration: When all samples have been issued. @@ -287,7 +285,8 @@ def __next__(self) -> IssuedSample: StopIteration: When scheduler has no more samples to issue. """ # Let raised StopIteration be propagated up the stack - s_idx, delay_ns = next(self._iterator) + # Ignore mypy error complaining that self._iterator may be None + s_idx, delay_ns = next(self._iterator) # type: ignore[call-overload] # Data loading is not timed for Time-to-Token metrics.
It is assumed that the # hypothetical user would have put the data into memory available for a network diff --git a/src/inference_endpoint/load_generator/sample.py b/src/inference_endpoint/load_generator/sample.py index bcfe89b7..18d5e032 100644 --- a/src/inference_endpoint/load_generator/sample.py +++ b/src/inference_endpoint/load_generator/sample.py @@ -107,7 +107,9 @@ def __init__(self): self.complete_hooks = [] def register_hook( - self, event_type: SampleEvent, hook: Callable[[StreamChunk | QueryResult], None] + self, + event_type: SampleEvent, + hook: Callable[[StreamChunk], None] | Callable[[QueryResult], None], ) -> None: if event_type == SampleEvent.FIRST_CHUNK: self.first_chunk_hooks.append(hook) diff --git a/src/inference_endpoint/load_generator/scheduler.py b/src/inference_endpoint/load_generator/scheduler.py index 997a354a..ae691d09 100644 --- a/src/inference_endpoint/load_generator/scheduler.py +++ b/src/inference_endpoint/load_generator/scheduler.py @@ -169,7 +169,7 @@ def next_sample_index(self) -> int: def uniform_delay_fn( - max_delay_ns: int = 0, rng: random.Random = random + max_delay_ns: int = 0, rng: random.Random | None = None ) -> Callable[[], float]: """Create a uniform delay function for schedulers. @@ -184,6 +184,7 @@ def uniform_delay_fn( Returns: Function that returns delay in nanoseconds (float). """ + rng = rng or random.Random() def _fn(): if max_delay_ns == 0: @@ -194,7 +195,7 @@ def _fn(): def poisson_delay_fn( - expected_queries_per_second: float, rng: random.Random = random + expected_queries_per_second: float, rng: random.Random | None = None ) -> Callable[[], float]: """Create a Poisson-distributed delay function for realistic online benchmarking. @@ -213,6 +214,7 @@ def poisson_delay_fn( Returns: Function that returns delay in nanoseconds (float). 
""" + rng = rng or random.Random() queries_per_ns = expected_queries_per_second / 1e9 def _fn(): @@ -278,7 +280,7 @@ def __init__( rng=self.runtime_settings.rng_sample_index, ) ) - self.delay_fn = None # Subclasses must set this + self.delay_fn: Callable[[], int] | None = None # Subclasses must set this def __iter__(self): """Iterate over (sample_index, delay_ns) pairs. @@ -375,7 +377,7 @@ class ConcurrencyScheduler(Scheduler, load_pattern=LoadPatternType.CONCURRENCY): def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): super().__init__(runtime_settings, sample_order_cls) - + assert runtime_settings.load_pattern is not None target_concurrency = runtime_settings.load_pattern.target_concurrency if target_concurrency is None or target_concurrency <= 0: raise ValueError( @@ -390,7 +392,7 @@ def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): # Register completion hook - free up slot when query completes SampleEventHandler.register_hook(SampleEvent.COMPLETE, self._release_slot) - # Unused (required by Scheduler interface) + # Unused (required by Scheduler interface) - returns 0 delay self.delay_fn = lambda: 0 def _release_slot(self, result=None): diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 88943e06..b8c6394f 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -51,7 +51,7 @@ def __init__( # EventRecorder will set this when all samples complete, helps avoid busy-waiting self.end_event = threading.Event() - self.thread = None + self.thread: threading.Thread | None = None # CPython GIL provides atomic boolean writes, no need for threading.Event() self.stop_requested = False @@ -62,7 +62,7 @@ def __init__( # Will be populated after the test finishes by _run_test self.report = None - self.sample_uuid_map = None + self.sample_uuid_map: dict[str, dict[str, int]] | None = None @property def 
is_running(self): @@ -147,8 +147,8 @@ def _run_test( if tokenizer_override is not None: tokenizer = tokenizer_override if has_model: - model = self.runtime_settings.model - if tokenizer is None: + model = getattr(self.runtime_settings, "model", None) + if tokenizer is None and model is not None: try: tokenizer = AutoTokenizer.from_pretrained( model if isinstance(model, str) else model.name @@ -184,7 +184,7 @@ def _run_test( report.to_json(save_to=Path(report_dir) / "result_summary.json") # Dump runtime settings to report directory - rt_settings_data = { + rt_settings_data: dict[str, int | str | None] = { "min_duration_ms": self.runtime_settings.min_duration_ms, "max_duration_ms": self.runtime_settings.max_duration_ms, "n_samples_from_dataset": self.runtime_settings.n_samples_from_dataset, @@ -196,9 +196,9 @@ def _run_test( # to retrieve the seed values. The best way to do this is probably a custom random.Random # class that stores the original seed as a read-only property, and unable to set the seed # after initialization. 
- if has_model: + if has_model and model is not None: rt_settings_data["model"] = ( - model if isinstance(model, str) else model.name + model if isinstance(model, str) else str(model.name) ) # TODO: After Zhihan's MR is merged, grab the scheduler class and other LG init settings @@ -218,12 +218,15 @@ def _run_test( if dump_events_log: reporter.dump_to_json(Path(report_dir) / "events.jsonl") - # Dump report to text file - report_path = report_dir / "report.txt" + # Display report to console report.display(fn=print, summary_only=True) - with open(report_path, "w") as f: - report.display(fn=f.write, summary_only=False, newline="\n") - logger.info(f"Report saved to {report_path}") + + # Dump report to text file if report_dir is provided + if report_dir: + report_path = Path(report_dir) / "report.txt" + with open(report_path, "w") as f: + report.display(fn=f.write, summary_only=False, newline="\n") + logger.info(f"Report saved to {report_path}") def wait_for_test_end(self, timeout: float | None = None) -> bool: """ @@ -235,6 +238,8 @@ def wait_for_test_end(self, timeout: float | None = None) -> bool: Returns: bool: True if the test thread has completed, False if it timed out. """ + if not self.thread: + return False self.thread.join(timeout=timeout) return not self.thread.is_alive() @@ -276,7 +281,7 @@ def start( The new BenchmarkSession. 
""" session = cls(runtime_settings, session_id=name) - load_generator = load_generator_cls(sample_issuer, dataset, scheduler, *args) + load_generator = load_generator_cls(sample_issuer, dataset, scheduler, *args) # type: ignore[arg-type] # Create accuracy test generators accuracy_test_generators = None @@ -293,7 +298,7 @@ def start( metric_target=runtime_settings.metric_target, reported_metrics=runtime_settings.reported_metrics, min_duration_ms=0, - max_duration_ms=None, + max_duration_ms=None, # type: ignore[arg-type] n_samples_from_dataset=ds.num_samples(), n_samples_to_issue=ds.num_samples() * ds.repeats, min_sample_count=ds.num_samples() * ds.repeats, @@ -306,7 +311,10 @@ def start( ) accuracy_test_generators[ds_name] = load_generator_cls( - sample_issuer, ds, acc_sched, *args + sample_issuer, + ds, + acc_sched, # type: ignore[arg-type] + *args, ) session.thread = threading.Thread( diff --git a/src/inference_endpoint/main.py b/src/inference_endpoint/main.py index 054798a5..c9f8853c 100644 --- a/src/inference_endpoint/main.py +++ b/src/inference_endpoint/main.py @@ -48,7 +48,7 @@ def run() -> None: # Use eager task factory for immediate coroutine execution # Tasks start executing synchronously until first await with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: - runner.get_loop().set_task_factory(asyncio.eager_task_factory) + runner.get_loop().set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type] runner.run(main()) except KeyboardInterrupt: logger.info("Application interrupted by user") diff --git a/src/inference_endpoint/metrics/metric.py b/src/inference_endpoint/metrics/metric.py index 01fa4e62..c4ea1f93 100644 --- a/src/inference_endpoint/metrics/metric.py +++ b/src/inference_endpoint/metrics/metric.py @@ -39,8 +39,8 @@ def is_valid(self, measurement: float) -> bool: class Throughput(Metric): REL_TOL = 0.1 # Relative tolerance for throughput - def __init__(self, target: float): - super().__init__(target) + def __init__(self, 
target_qps: float): + super().__init__(target_qps) def is_valid(self, measurement: float) -> bool: return math.isclose(measurement, self.target, rel_tol=self.REL_TOL) @@ -49,7 +49,9 @@ def is_valid(self, measurement: float) -> bool: class QueryLatency(Metric): REL_TOL = 0.1 # Relative tolerance for query latency - def __init__(self, target_latency_ms: float = None, target_qps: float = None): + def __init__( + self, target_latency_ms: float | None = None, target_qps: float | None = None + ): """ Args: target_latency_ms: The target latency in milliseconds @@ -68,7 +70,8 @@ def is_valid(self, measurement: float) -> bool: class TTFT(Metric): - def __init__(self, max_ttft_latency_ms: float): + def __init__(self, max_ttft_latency_ms: float | None): + assert max_ttft_latency_ms is not None super().__init__(max_ttft_latency_ms) def is_valid(self, measurement: float) -> bool: @@ -76,7 +79,8 @@ def is_valid(self, measurement: float) -> bool: class TPOT(Metric): - def __init__(self, max_tpot_latency_ms: float): + def __init__(self, max_tpot_latency_ms: float | None): + assert max_tpot_latency_ms is not None super().__init__(max_tpot_latency_ms) def is_valid(self, measurement: float) -> bool: diff --git a/src/inference_endpoint/metrics/recorder.py b/src/inference_endpoint/metrics/recorder.py index 0de95b23..b1dec159 100644 --- a/src/inference_endpoint/metrics/recorder.py +++ b/src/inference_endpoint/metrics/recorder.py @@ -101,10 +101,10 @@ def to_insert_params(self) -> tuple[str, str, int, bytes]: ) -def register_cleanup(file_path: str): +def register_cleanup(file_path: Path): if multiprocessing.parent_process() is not None: return - atexit.register(partial(Path(file_path).unlink, missing_ok=True)) + atexit.register(partial(file_path.unlink, missing_ok=True)) logger.debug(f"Registered at-exit cleanup for {file_path}") @@ -163,7 +163,7 @@ def __init__( if self.connection_name not in EventRecorder._created_session_dbs: register_cleanup(self.connection_name) - 
EventRecorder._created_session_dbs.add(self.connection_name) + EventRecorder._created_session_dbs.add(str(self.connection_name)) if not Path(self.connection_name).parent.exists(): raise FileNotFoundError( @@ -475,7 +475,7 @@ def record_exception( EventRecorder.record_event( SessionEvent.ERROR, time.monotonic_ns(), - sample_uuid=sample_uuid, + sample_uuid=sample_uuid or "", data={ "error_type": exc_value.__class__.__name__, "error_message": str(exc_value), diff --git a/src/inference_endpoint/metrics/reporter.py b/src/inference_endpoint/metrics/reporter.py index 2d613dd4..bdf5907b 100644 --- a/src/inference_endpoint/metrics/reporter.py +++ b/src/inference_endpoint/metrics/reporter.py @@ -152,7 +152,7 @@ def __len__(self) -> int: if self.repeats is None: return len(self.rows) else: - return int(self.repeats.sum()) + return int(sum(self.repeats)) def filter_uuid(self, uuid: str, only_first: bool = False) -> Any: """Returns the values for the given sample UUID. diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 8f59e969..868a44d9 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -49,7 +49,7 @@ class SSEMessage(msgspec.Struct): # ============================================================================ -class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True): +class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """Chat message in OpenAI format.""" role: str @@ -57,7 +57,7 @@ class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True): name: str | None = None -class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True): +class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """OpenAI chat completion request.""" model: str @@ -76,7 +76,7 @@ class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True): user: str | None = None 
-class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True): +class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """Response message from OpenAI.""" role: str @@ -84,7 +84,7 @@ class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults= refusal: str | None -class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True): +class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """A single choice in the completion response.""" index: int @@ -92,7 +92,7 @@ class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True): finish_reason: str | None -class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True): +class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """Token usage statistics.""" prompt_tokens: int @@ -100,7 +100,7 @@ class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True): total_tokens: int -class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True): +class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] """OpenAI chat completion response (msgspec version).""" id: str diff --git a/src/inference_endpoint/sglang/types.py b/src/inference_endpoint/sglang/types.py index e5d2fd04..65dc2662 100644 --- a/src/inference_endpoint/sglang/types.py +++ b/src/inference_endpoint/sglang/types.py @@ -26,7 +26,7 @@ # ============================================================================ -class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True): +class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] max_new_tokens: int = 32768 """int: Maximum number of tokens to generate per request (1-32768)""" @@ -40,13 +40,13 @@ class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True): """float: 
Top-p/nucleus sampling (cumulative probability threshold). 0.0-1.0, typically 1.0 for no filterin""" -class SGLangGenerateRequest(msgspec.Struct, kw_only=True, omit_defaults=True): +class SGLangGenerateRequest(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] input_ids: list[int] sampling_params: SamplingParams stream: bool -class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True): +class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] id: str finish_reason: dict[str, Any] prompt_tokens: int @@ -57,7 +57,7 @@ class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True): e2e_latency: float -class SGLangGenerateResponse(msgspec.Struct, kw_only=True, omit_defaults=True): +class SGLangGenerateResponse(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg] text: str output_ids: list[int] meta_info: MetaInfo diff --git a/src/inference_endpoint/testing/echo_server.py b/src/inference_endpoint/testing/echo_server.py index bb8a1bda..cfcef3b8 100644 --- a/src/inference_endpoint/testing/echo_server.py +++ b/src/inference_endpoint/testing/echo_server.py @@ -249,10 +249,13 @@ async def _handle_echo_chat_completions_request( raw_payload = await request.text() json_payload = json.loads(raw_payload) completion_request = CreateChatCompletionRequest(**json_payload) + raw_request = "" if completion_request.messages and len(completion_request.messages) > 0: for message in completion_request.messages: if str(message.root.role.value) == "user": - raw_request = message.root.content + content = message.root.content + # Convert content to string - handle various content types + raw_request = str(content) if content is not None else "" break else: raise ValueError("Request must contain at least one message") diff --git a/src/inference_endpoint/utils/__init__.py b/src/inference_endpoint/utils/__init__.py index e0d3badd..9492fb02 100644 --- a/src/inference_endpoint/utils/__init__.py +++ 
b/src/inference_endpoint/utils/__init__.py @@ -64,7 +64,7 @@ def byte_quantity_to_str( while n_bytes >= 1024: if suffixes[suffix_idx] == max_unit or suffix_idx >= len(suffixes) - 1: break - n_bytes /= 1024 + n_bytes //= 1024 suffix_idx += 1 n_bytes = int(n_bytes) suffix = suffixes[suffix_idx] diff --git a/src/inference_endpoint/utils/dataset_utils.py b/src/inference_endpoint/utils/dataset_utils.py index 35322270..546346b6 100644 --- a/src/inference_endpoint/utils/dataset_utils.py +++ b/src/inference_endpoint/utils/dataset_utils.py @@ -32,7 +32,7 @@ def tokenizer_stats( if end_index == -1: end_index = tokenizer.vocab_size token_to_text = {} # dictionary from token ids to text - token_leng_counts = {} # histogram of token lengths + token_leng_counts: dict[int, int] = {} # histogram of token lengths for i in range(start_index, end_index): token_to_text[i] = tokenizer.decode([i]) token_leng_counts[len(token_to_text[i])] = 1 + token_leng_counts.get( diff --git a/src/inference_endpoint/utils/logging.py b/src/inference_endpoint/utils/logging.py index 8677dd75..aae4650e 100644 --- a/src/inference_endpoint/utils/logging.py +++ b/src/inference_endpoint/utils/logging.py @@ -22,6 +22,7 @@ import logging import os import sys +from typing import Literal from colorama import Fore, Style from colorama import init as _colorama_init @@ -49,7 +50,7 @@ def __init__( self, fmt: str | None = None, datefmt: str | None = None, - style: str = "%", + style: Literal["%", "{", "$"] = "%", use_color: bool = False, ): """Initialize the formatter. 
diff --git a/tests/futures_client.py b/tests/futures_client.py index 989bfa63..86d21fcd 100644 --- a/tests/futures_client.py +++ b/tests/futures_client.py @@ -41,12 +41,16 @@ def __init__( # Start response handler on client's loop self._pending: dict[str | int, concurrent.futures.Future] = {} + assert ( + self.loop is not None + ), "Client loop should be initialized by parent __init__" self._handler_future = asyncio.run_coroutine_threadsafe( self._handle_responses(), self.loop ) self._is_shutting_down = False - def issue(self, query: Query) -> concurrent.futures.Future[QueryResult]: + # TODO (vir): fix this type ignore since the base class doesn't have a return value + def issue(self, query: Query) -> concurrent.futures.Future[QueryResult]: # type: ignore[override] """Issue query and return a future for the result.""" if self._is_shutting_down: raise RuntimeError("Cannot issue query: client is shutting down") diff --git a/tests/integration/test_end_to_end_oracle.py b/tests/integration/test_end_to_end_oracle.py index cbf5eb35..ba1cd2c9 100644 --- a/tests/integration/test_end_to_end_oracle.py +++ b/tests/integration/test_end_to_end_oracle.py @@ -15,6 +15,7 @@ import logging import random +from pathlib import Path from urllib.parse import urljoin import pytest @@ -40,7 +41,7 @@ class DeepSeekR1SampleIssuer(HttpClientSampleIssuer): - def __init__(self, tmp_path: str, url: str): + def __init__(self, tmp_path: Path, url: str): self.http_config = HTTPClientConfig( endpoint_urls=[urljoin(url, "/v1/chat/completions")], warmup_connections=False, diff --git a/tests/performance/endpoint_client/test_http_client_performance_single.py b/tests/performance/endpoint_client/test_http_client_performance_single.py index bc838264..86267391 100644 --- a/tests/performance/endpoint_client/test_http_client_performance_single.py +++ b/tests/performance/endpoint_client/test_http_client_performance_single.py @@ -158,7 +158,9 @@ def run_performance_test( # - Streaming (server mode): Use Poisson 
scheduler to simulate realistic load arrival # - Offline: Use max throughput scheduler to stress test the system if stream: - sched_cls = PoissonDistributionScheduler + sched_cls: type[PoissonDistributionScheduler] | type[MaxThroughputScheduler] = ( + PoissonDistributionScheduler + ) else: sched_cls = MaxThroughputScheduler scheduler = sched_cls(rt_settings, WithoutReplacementSampleOrder) @@ -240,8 +242,11 @@ def assert_performance_requirements( message_size: Optional message size for better error messages """ achieved_qps = summary["qps"] + assert achieved_qps is not None and isinstance(achieved_qps, float) issued_qps = summary["issue_qps"] + assert issued_qps is not None and isinstance(issued_qps, float) min_achievement = PERFORMANCE_CONFIG["target_qps_tolerance"] + assert min_achievement is not None and isinstance(min_achievement, float) # Log results size_info = f" (size={message_size} characters)" if message_size else "" @@ -335,7 +340,7 @@ def test_offline_baseline_performance(self, http_client): @pytest.mark.performance @pytest.mark.xdist_group(name="serial_performance") - @pytest.mark.parametrize("message_size", PERFORMANCE_CONFIG["message_sizes"]) + @pytest.mark.parametrize("message_size", PERFORMANCE_CONFIG["message_sizes"]) # type: ignore[arg-type] def test_streaming_throughput_various_message_sizes( self, http_client, message_size ): @@ -356,7 +361,7 @@ def test_streaming_throughput_various_message_sizes( @pytest.mark.performance @pytest.mark.xdist_group(name="serial_performance") - @pytest.mark.parametrize("message_size", PERFORMANCE_CONFIG["message_sizes"]) + @pytest.mark.parametrize("message_size", PERFORMANCE_CONFIG["message_sizes"]) # type: ignore[arg-type] def test_offline_throughput_various_message_sizes(self, http_client, message_size): """Validate offline mode maintains max throughput across different message sizes.""" summary = run_performance_test( diff --git a/tests/performance/test_recorder.py b/tests/performance/test_recorder.py index 
24eae7a3..f2d9cc3b 100644 --- a/tests/performance/test_recorder.py +++ b/tests/performance/test_recorder.py @@ -18,6 +18,7 @@ import time from dataclasses import dataclass, fields from pathlib import Path +from typing import TextIO import pytest from inference_endpoint.load_generator.events import SampleEvent, SessionEvent @@ -41,7 +42,7 @@ def __init__(self, log_file: Path | str | None = None): if log_file is None: log_file = Path("/tmp/recorder_timing_log.txt") self.log_file = Path(log_file) - self.f_obj = None + self.f_obj: TextIO | None = None def __enter__(self): if self.f_obj is not None: @@ -50,10 +51,12 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): + assert self.f_obj is not None self.f_obj.close() self.f_obj = None def log(self, key: str, duration_sec: float, variant: str = "default"): + assert self.f_obj is not None self.f_obj.write(f"[{key}] {variant}: {duration_sec} sec.\n") diff --git a/tests/test_helpers.py b/tests/test_helpers.py index bc4e33be..589e97df 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -24,6 +24,7 @@ import random import string import uuid +from asyncio import Future from concurrent.futures import ThreadPoolExecutor from pathlib import Path @@ -88,7 +89,7 @@ def create_test_query( query = create_test_query(prompt_size=500, seed=42) """ # Use a local random instance for reproducibility if seed is provided - rng = random.Random(seed) if seed is not None else random + rng = random.Random(seed) if seed is not None else random.Random() # Generate prompt from random words until we reach approximately the target size words = [] @@ -223,7 +224,7 @@ def __init__(self, compute_func=None, n_workers: int = 4): else: self.compute_func = compute_func self.executor = ThreadPoolExecutor(max_workers=n_workers) - self.futures = [] + self.futures: list[Future[None]] = [] def shutdown(self, wait: bool = True): """Shutdown the executor and wait for all tasks to complete.