diff --git a/README.md b/README.md
index 79cbf43d..90021257 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,37 @@ data-designer config list # View current settings
 
 ---
 
+## Telemetry
+
+Data Designer collects telemetry to help us improve the library for developers. We collect:
+
+* The names of models used
+* The count of input tokens
+* The count of output tokens
+
+**No user or device information is collected.** This data is not used to track individual user behavior; it is aggregated to show which models are most popular for synthetic data generation (SDG). We will share this usage data with the community.
+
+Specifically, the model name defined in a `ModelConfig` object is what is collected. For example, given the following config:
+
+```python
+ModelConfig(
+    alias="nv-reasoning",
+    model="openai/gpt-oss-20b",
+    provider="nvidia",
+    inference_parameters=InferenceParameters(
+        temperature=0.3,
+        top_p=0.9,
+        max_tokens=4096,
+    ),
+)
+```
+
+The value `openai/gpt-oss-20b` would be collected.
+
+To disable telemetry capture, set `NEMO_TELEMETRY_ENABLED=false`.
+
+---
+
 ## License
 
 Apache License 2.0 – see [LICENSE](LICENSE) for details.
diff --git a/pyproject.toml b/pyproject.toml
index bdd783b6..bd4ee6d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,6 +98,11 @@ env = [
     "DISABLE_DATA_DESIGNER_PLUGINS=true",
 ]
 
+[tool.coverage.run]
+omit = [
+    "src/data_designer/engine/models/telemetry.py",
+]
+
 [tool.uv]
 package = true
 required-version = ">=0.7.10"
diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py
index 063aa15c..77b6459b 100644
--- a/src/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -1,12 +1,15 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import functools +import importlib.metadata import json import logging import time +import uuid from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING, Callable import pandas as pd @@ -35,14 +38,21 @@ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import ( DatasetBatchManager, ) +from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider +if TYPE_CHECKING: + from data_designer.engine.models.usage import ModelUsageStats + logger = logging.getLogger(__name__) +_CLIENT_VERSION: str = importlib.metadata.version("data_designer") + + class ColumnWiseDatasetBuilder: def __init__( self, @@ -89,11 +99,12 @@ def build( generators = self._initialize_generators() start_time = time.perf_counter() + group_id = uuid.uuid4().hex self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) for batch_idx in range(self.batch_manager.num_batches): logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") - self._run_batch(generators) + self._run_batch(generators, batch_mode="batch", group_id=group_id) df_batch = self._run_processors( stage=BuildStage.POST_BATCH, dataframe=self.batch_manager.get_current_batch(as_dataframe=True), @@ -114,10 +125,10 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: self._run_model_health_check_if_needed() generators = self._initialize_generators() - + group_id = uuid.uuid4().hex start_time = time.perf_counter() self.batch_manager.start(num_records=num_records, buffer_size=num_records) - self._run_batch(generators, save_partial_results=False) + self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id) dataset = self.batch_manager.get_current_batch(as_dataframe=True) self.batch_manager.reset() @@ -143,7 +154,10 @@ def _initialize_generators(self) -> list[ColumnGenerator]: for config in self._column_configs ] - def _run_batch(self, generators: list[ColumnGenerator], *, save_partial_results: bool = True) -> None: + def _run_batch( + self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str + ) -> None: + pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() for generator in generators: generator.log_pre_generation() try: @@ -166,6 +180,12 @@ def _run_batch(self, generators: list[ColumnGenerator], *, save_partial_results: ) raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}") + try: + usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) + self._emit_batch_inference_events(batch_mode, usage_deltas, group_id) + except Exception: + pass + def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None: df = generator.generate_from_scratch(self.batch_manager.num_records_batch) self.batch_manager.add_records(df.to_dict(orient="records")) @@ -289,3 +309,25 @@ def _write_configs(self) -> None: json_file_name="model_configs.json", configs=self._resource_provider.model_registry.model_configs.values(), ) + + def 
_emit_batch_inference_events( + self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str + ) -> None: + if not usage_deltas: + return + + events = [ + InferenceEvent( + nemo_source=NemoSourceEnum.DATADESIGNER, + task=batch_mode, + task_status=TaskStatusEnum.SUCCESS, + model=model_name, + input_tokens=delta.token_usage.input_tokens, + output_tokens=delta.token_usage.output_tokens, + ) + for model_name, delta in usage_deltas.items() + ] + + with TelemetryHandler(source_client_version=_CLIENT_VERSION, session_id=group_id) as telemetry_handler: + for event in events: + telemetry_handler.enqueue(event) diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index 8fe61d59..f08e585a 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -9,6 +9,7 @@ from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches +from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats from data_designer.engine.secret_resolver import SecretResolver logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def __init__( self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_configs = {} - self._models = {} + self._models: dict[str, ModelFacade] = {} self._set_model_configs(model_configs) @property @@ -69,6 +70,31 @@ def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict]: if model.usage_stats.has_usage } + def get_model_usage_snapshot(self) -> dict[str, ModelUsageStats]: + return { + model.model_name: model.usage_stats.model_copy(deep=True) + for model in self._models.values() + if model.usage_stats.has_usage + } + + def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, ModelUsageStats]: + deltas = {} + for model_name, current in self.get_model_usage_snapshot().items(): + prev = snapshot.get(model_name) + delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0) + delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0) + delta_successful = current.request_usage.successful_requests - ( + prev.request_usage.successful_requests if prev else 0 + ) + delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0) + + if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0: + deltas[model_name] = ModelUsageStats( + token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output), + request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed), + ) + return deltas + def get_model_provider(self, *, model_alias: str) -> ModelProvider: model_config = self.get_model_config(model_alias=model_alias) return self._model_provider_registry.get_provider(model_config.provider) diff --git a/src/data_designer/engine/models/telemetry.py b/src/data_designer/engine/models/telemetry.py new file mode 100644 index 00000000..23b3b7be --- /dev/null +++ b/src/data_designer/engine/models/telemetry.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Telemetry handler for NeMo products. 
+ +Environment variables: +- NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled. +- NEMO_DEPLOYMENT_TYPE: The deployment type the event came from. +- NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to. +""" + +from __future__ import annotations + +import asyncio +import os +import platform +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Any, ClassVar + +import httpx +from pydantic import BaseModel, Field + +TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes") +CLIENT_ID = "184482118588404" +NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0" +MAX_RETRIES = 3 +NEMO_TELEMETRY_ENDPOINT = os.getenv( + "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json" +).lower() +CPU_ARCHITECTURE = platform.uname().machine + + +class NemoSourceEnum(str, Enum): + INFERENCE = "inference" + AUDITOR = "auditor" + DATADESIGNER = "datadesigner" + EVALUATOR = "evaluator" + GUARDRAILS = "guardrails" + UNDEFINED = "undefined" + + +class DeploymentTypeEnum(str, Enum): + LIBRARY = "library" + API = "api" + UNDEFINED = "undefined" + + +_deployment_type_raw = os.getenv("NEMO_DEPLOYMENT_TYPE", "library").lower() +try: + DEPLOYMENT_TYPE = DeploymentTypeEnum(_deployment_type_raw) +except ValueError: + valid_values = [e.value for e in DeploymentTypeEnum] + raise ValueError( + f"Invalid NEMO_DEPLOYMENT_TYPE: {_deployment_type_raw!r}. Must be one of: {valid_values}" + ) from None + + +class TaskStatusEnum(str, Enum): + SUCCESS = "success" + FAILURE = "failure" + UNDEFINED = "undefined" + + +class TelemetryEvent(BaseModel): + _event_name: ClassVar[str] # Subclasses must define this + _schema_version: ClassVar[str] = "1.3" + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if "_event_name" not in cls.__dict__: + raise TypeError(f"{cls.__name__} must define '_event_name' class variable") + + +class InferenceEvent(TelemetryEvent): + _event_name: ClassVar[str] = "inference_event" + + nemo_source: NemoSourceEnum = Field( + ..., + alias="nemoSource", + description="The NeMo product that created the event (i.e. data-designer).", + ) + task: str = Field( + ..., + description="The type of task that was performed that generated the inference event (i.e. preview-job, batch-job).", + ) + task_status: TaskStatusEnum = Field( + ..., + alias="taskStatus", + description="The status of the task.", + ) + deployment_type: DeploymentTypeEnum = Field( + default=DEPLOYMENT_TYPE, + alias="deploymentType", + description="The deployment type the event came from.", + ) + model: str = Field( + ..., + description="The name of the model that was used.", + ) + model_group: str = Field( + default="undefined", + alias="modelGroup", + description="An optional identifier to group models together.", + ) + input_bytes: int = Field( + default=-1, + alias="inputBytes", + description="Number of bytes provided as input to the model. -1 if not available.", + ge=-9223372036854775808, + le=9223372036854775807, + ) + input_tokens: int = Field( + default=-1, + alias="inputTokens", + description="Number of tokens provided as input to the model. -1 if not available.", + ge=-9223372036854775808, + le=9223372036854775807, + ) + output_bytes: int = Field( + default=-1, + alias="outputBytes", + description="Number of bytes returned by the model. 
-1 if not available.", + ge=-9223372036854775808, + le=9223372036854775807, + ) + output_tokens: int = Field( + default=-1, + alias="outputTokens", + description="Number of tokens returned by the model. -1 if not available.", + ge=-9223372036854775808, + le=9223372036854775807, + ) + + model_config = {"populate_by_name": True} + + +@dataclass +class QueuedEvent: + event: TelemetryEvent + timestamp: datetime + retry_count: int = 0 + + +def _get_iso_timestamp(dt: datetime | None = None) -> str: + if dt is None: + dt = datetime.now(timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{dt.microsecond // 1000:03d}Z" + + +def build_payload( + events: list[QueuedEvent], *, source_client_version: str, session_id: str = "undefined" +) -> dict[str, Any]: + return { + "browserType": "undefined", # do not change + "clientId": CLIENT_ID, + "clientType": "Native", # do not change + "clientVariant": "Release", # do not change + "clientVer": source_client_version, + "cpuArchitecture": CPU_ARCHITECTURE, + "deviceGdprBehOptIn": "None", # do not change + "deviceGdprFuncOptIn": "None", # do not change + "deviceGdprTechOptIn": "None", # do not change + "deviceId": "undefined", # do not change + "deviceMake": "undefined", # do not change + "deviceModel": "undefined", # do not change + "deviceOS": "undefined", # do not change + "deviceOSVersion": "undefined", # do not change + "deviceType": "undefined", # do not change + "eventProtocol": "1.6", # do not change + "eventSchemaVer": events[0].event._schema_version, + "eventSysVer": NEMO_TELEMETRY_VERSION, + "externalUserId": "undefined", # do not change + "gdprBehOptIn": "None", # do not change + "gdprFuncOptIn": "None", # do not change + "gdprTechOptIn": "None", # do not change + "idpId": "undefined", # do not change + "integrationId": "undefined", # do not change + "productName": "undefined", # do not change + "productVersion": "undefined", # do not change + "sentTs": _get_iso_timestamp(), + "sessionId": session_id, + "userId": "undefined", # do not change + "events": [ + { + "ts": _get_iso_timestamp(queued.timestamp), + "parameters": queued.event.model_dump(by_alias=True), + "name": queued.event._event_name, + } + for queued in events + ], + } + + +class TelemetryHandler: + """ + Handles telemetry event batching, flushing, and retry logic for NeMo products. + + Args: + flush_interval_seconds (float): The interval in seconds to flush the events. + max_queue_size (int): The maximum number of events to queue before flushing. + max_retries (int): The maximum number of times to retry sending an event. + source_client_version (str): The version of the source client. This should be the version of + the actual NeMo product that is sending the events, typically the same as the version of + a PyPi package that a user would install. + session_id (str): An optional session ID to associate with the events. + This should be a unique identifier for the session, such as a UUID. + It is used to group events together. 
+ """ + + def __init__( + self, + flush_interval_seconds: float = 120.0, + max_queue_size: int = 50, + max_retries: int = MAX_RETRIES, + source_client_version: str = "undefined", + session_id: str = "undefined", + ): + self._flush_interval = flush_interval_seconds + self._max_queue_size = max_queue_size + self._max_retries = max_retries + self._events: list[QueuedEvent] = [] + self._dlq: list[QueuedEvent] = [] # Dead letter queue for retry + self._flush_signal = asyncio.Event() + self._timer_task: asyncio.Task | None = None + self._running = False + self._source_client_version = source_client_version + self._session_id = session_id + + async def astart(self) -> None: + if self._running: + return + self._running = True + self._timer_task = asyncio.create_task(self._timer_loop()) + + async def astop(self) -> None: + self._running = False + self._flush_signal.set() + if self._timer_task: + self._timer_task.cancel() + try: + await self._timer_task + except asyncio.CancelledError: + pass + self._timer_task = None + await self._flush_events() + + async def aflush(self) -> None: + self._flush_signal.set() + + def start(self) -> None: + self._run_sync(self.astart()) + + def stop(self) -> None: + self._run_sync(self.astop()) + + def flush(self) -> None: + self._flush_signal.set() + + def enqueue(self, event: TelemetryEvent) -> None: + if not TELEMETRY_ENABLED: + return + if not isinstance(event, TelemetryEvent): + # Silently fail as we prioritize not disrupting upstream call sites and telemetry is best effort + return + queued = QueuedEvent(event=event, timestamp=datetime.now(timezone.utc)) + self._events.append(queued) + if len(self._events) >= self._max_queue_size: + self._flush_signal.set() + + def _run_sync(self, coro: Any) -> Any: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + else: + return asyncio.run(coro) + + def __enter__(self) -> TelemetryHandler: + self.start() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + async def __aenter__(self) -> TelemetryHandler: + await self.astart() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.astop() + + async def _timer_loop(self) -> None: + while self._running: + try: + await asyncio.wait_for( + self._flush_signal.wait(), + timeout=self._flush_interval, + ) + except asyncio.TimeoutError: + pass + self._flush_signal.clear() + await self._flush_events() + + async def _flush_events(self) -> None: + dlq_events, self._dlq = self._dlq, [] + new_events, self._events = self._events, [] + events_to_send = dlq_events + new_events + if events_to_send: + await self._send_events(events_to_send) + + async def _send_events(self, events: list[QueuedEvent]) -> None: + async with httpx.AsyncClient() as client: + await self._send_events_with_client(client, events) + + async def _send_events_with_client(self, client: httpx.AsyncClient, events: list[QueuedEvent]) -> None: + if not events: + return + + payload = build_payload(events, source_client_version=self._source_client_version, session_id=self._session_id) + try: + response = await client.post(NEMO_TELEMETRY_ENDPOINT, json=payload) + # 2xx, 400, 422 are all considered complete (no retry) + # 400/422 indicate bad payload which retrying won't fix + if response.status_code in (400, 422) 
or response.is_success: + return + # 413 (payload too large) - split and retry + if response.status_code == 413: + if len(events) == 1: + # Can't split further, drop the event + return + mid = len(events) // 2 + await self._send_events_with_client(client, events[:mid]) + await self._send_events_with_client(client, events[mid:]) + return + if response.status_code == 408 or response.status_code >= 500: + self._add_to_dlq(events) + except httpx.HTTPError: + self._add_to_dlq(events) + + def _add_to_dlq(self, events: list[QueuedEvent]) -> None: + for queued in events: + queued.retry_count += 1 + if queued.retry_count > self._max_retries: + continue + self._dlq.append(queued) diff --git a/tests/engine/dataset_builders/test_column_wise_builder.py b/tests/engine/dataset_builders/test_column_wise_builder.py index 4e62c478..dd8c84db 100644 --- a/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/tests/engine/dataset_builders/test_column_wise_builder.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import Mock +from unittest.mock import Mock, patch import pandas as pd import pytest @@ -15,6 +15,8 @@ ) from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig +from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum +from data_designer.engine.models.usage import ModelUsageStats, TokenUsageStats from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry @@ -147,6 +149,7 @@ def test_column_wise_dataset_builder_build_method_basic_flow( ): stub_resource_provider.model_registry.run_health_check = Mock() stub_resource_provider.model_registry.get_model_usage_stats = Mock(return_value={"test": "stats"}) + stub_resource_provider.model_registry.models = {} # Mock the model config to return proper max_parallel_requests mock_model_config = Mock() @@ -212,3 +215,94 @@ def test_column_wise_dataset_builder_initialize_processors(stub_column_wise_buil def test_constants_max_concurrency_constant(): assert MAX_CONCURRENCY_PER_NON_LLM_GENERATOR == 4 + + +@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +def test_emit_batch_inference_events_emits_from_deltas( + mock_telemetry_handler_class: Mock, + stub_resource_provider: Mock, + stub_test_column_configs: list, + stub_test_processor_configs: list, +) -> None: + usage_deltas = {"test-model": ModelUsageStats(token_usage=TokenUsageStats(input_tokens=50, output_tokens=150))} + + builder = ColumnWiseDatasetBuilder( + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) + + session_id = "550e8400-e29b-41d4-a716-446655440000" + + mock_handler_instance = Mock() + mock_telemetry_handler_class.return_value.__enter__ = Mock(return_value=mock_handler_instance) + mock_telemetry_handler_class.return_value.__exit__ = Mock(return_value=False) + + builder._emit_batch_inference_events("batch", usage_deltas, session_id) + + mock_telemetry_handler_class.assert_called_once() + call_kwargs = mock_telemetry_handler_class.call_args[1] + assert call_kwargs["session_id"] == session_id + + mock_handler_instance.enqueue.assert_called_once() + event = mock_handler_instance.enqueue.call_args[0][0] + + assert isinstance(event, InferenceEvent) + assert event.task == 
"batch" + assert event.task_status == TaskStatusEnum.SUCCESS + assert event.nemo_source == NemoSourceEnum.DATADESIGNER + assert event.model == "test-model" + assert event.input_tokens == 50 + assert event.output_tokens == 150 + + +@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +def test_emit_batch_inference_events_skips_when_no_deltas( + mock_telemetry_handler_class: Mock, + stub_resource_provider: Mock, + stub_test_column_configs: list, + stub_test_processor_configs: list, +) -> None: + usage_deltas: dict[str, ModelUsageStats] = {} + + builder = ColumnWiseDatasetBuilder( + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) + + session_id = "550e8400-e29b-41d4-a716-446655440000" + builder._emit_batch_inference_events("batch", usage_deltas, session_id) + + mock_telemetry_handler_class.assert_not_called() + + +@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +def test_emit_batch_inference_events_handles_multiple_models( + mock_telemetry_handler_class: Mock, + stub_resource_provider: Mock, + stub_test_column_configs: list, + stub_test_processor_configs: list, +) -> None: + usage_deltas = { + "model-a": ModelUsageStats(token_usage=TokenUsageStats(input_tokens=100, output_tokens=200)), + "model-b": ModelUsageStats(token_usage=TokenUsageStats(input_tokens=50, output_tokens=75)), + } + + builder = ColumnWiseDatasetBuilder( + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) + + session_id = "550e8400-e29b-41d4-a716-446655440000" + mock_handler_instance = Mock() + mock_telemetry_handler_class.return_value.__enter__ = Mock(return_value=mock_handler_instance) + mock_telemetry_handler_class.return_value.__exit__ = Mock(return_value=False) + + builder._emit_batch_inference_events("preview", usage_deltas, session_id) + + assert mock_handler_instance.enqueue.call_count == 2 + events = [call[0][0] for call in mock_handler_instance.enqueue.call_args_list] + model_names = {e.model for e in events} + assert model_names == {"model-a", "model-b"} diff --git a/tests/engine/models/test_model_registry.py b/tests/engine/models/test_model_registry.py index 4ea5a447..98814517 100644 --- a/tests/engine/models/test_model_registry.py +++ b/tests/engine/models/test_model_registry.py @@ -9,7 +9,7 @@ from data_designer.engine.models.errors import ModelAuthenticationError from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.registry import ModelRegistry, create_model_registry -from data_designer.engine.models.usage import RequestUsageStats, TokenUsageStats +from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats @pytest.fixture @@ -157,6 +157,120 @@ def test_get_model_usage_stats( assert set(usage_stats.keys()) == set(expected_keys) +@pytest.mark.parametrize( + "test_case,expected_keys", + [ + ("no_models", []), + ("with_usage", ["stub-model-text", "stub-model-reasoning"]), + ("no_usage", []), + ], +) +def test_get_model_usage_snapshot( + stub_model_registry: ModelRegistry, + stub_empty_model_registry: ModelRegistry, + test_case: str, + expected_keys: list[str], +) -> None: + if test_case == "no_models": + snapshot = stub_empty_model_registry.get_model_usage_snapshot() + assert snapshot == {} + elif test_case == "with_usage": + text_model = stub_model_registry.get_model(model_alias="stub-text") + 
reasoning_model = stub_model_registry.get_model(model_alias="stub-reasoning") + + text_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=10, output_tokens=100), + request_usage=RequestUsageStats(successful_requests=5, failed_requests=1), + ) + reasoning_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=20, output_tokens=200), + request_usage=RequestUsageStats(successful_requests=10, failed_requests=2), + ) + + snapshot = stub_model_registry.get_model_usage_snapshot() + + assert set(snapshot.keys()) == set(expected_keys) + assert all(isinstance(stats, ModelUsageStats) for stats in snapshot.values()) + + assert snapshot["stub-model-text"].token_usage.input_tokens == 10 + assert snapshot["stub-model-text"].token_usage.output_tokens == 100 + assert snapshot["stub-model-reasoning"].token_usage.input_tokens == 20 + assert snapshot["stub-model-reasoning"].token_usage.output_tokens == 200 + + snapshot["stub-model-text"].token_usage.input_tokens = 999 + assert text_model.usage_stats.token_usage.input_tokens == 10 + else: + stub_model_registry.get_model(model_alias="stub-text") + stub_model_registry.get_model(model_alias="stub-reasoning") + + snapshot = stub_model_registry.get_model_usage_snapshot() + assert snapshot == {} + + +@pytest.mark.parametrize( + "test_case,expected_keys", + [ + ("no_prior_usage", ["stub-model-text"]), + ("with_prior_usage", ["stub-model-text"]), + ("no_change", []), + ], +) +def test_get_usage_deltas( + stub_model_registry: ModelRegistry, + test_case: str, + expected_keys: list[str], +) -> None: + text_model = stub_model_registry.get_model(model_alias="stub-text") + + if test_case == "no_prior_usage": + # Empty snapshot, then add usage + pre_snapshot: dict[str, ModelUsageStats] = {} + text_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=50, output_tokens=100), + request_usage=RequestUsageStats(successful_requests=5, failed_requests=1), + ) + + deltas = stub_model_registry.get_usage_deltas(pre_snapshot) + + assert set(deltas.keys()) == set(expected_keys) + assert deltas["stub-model-text"].token_usage.input_tokens == 50 + assert deltas["stub-model-text"].token_usage.output_tokens == 100 + assert deltas["stub-model-text"].request_usage.successful_requests == 5 + assert deltas["stub-model-text"].request_usage.failed_requests == 1 + + elif test_case == "with_prior_usage": + # Add initial usage, take snapshot, add more usage + text_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=100, output_tokens=200), + request_usage=RequestUsageStats(successful_requests=10, failed_requests=2), + ) + pre_snapshot = stub_model_registry.get_model_usage_snapshot() + + text_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=50, output_tokens=75), + request_usage=RequestUsageStats(successful_requests=3, failed_requests=1), + ) + + deltas = stub_model_registry.get_usage_deltas(pre_snapshot) + + assert set(deltas.keys()) == set(expected_keys) + assert deltas["stub-model-text"].token_usage.input_tokens == 50 + assert deltas["stub-model-text"].token_usage.output_tokens == 75 + assert deltas["stub-model-text"].request_usage.successful_requests == 3 + assert deltas["stub-model-text"].request_usage.failed_requests == 1 + + else: # no_change + text_model.usage_stats.extend( + token_usage=TokenUsageStats(input_tokens=100, output_tokens=200), + request_usage=RequestUsageStats(successful_requests=10, failed_requests=2), + ) + pre_snapshot = stub_model_registry.get_model_usage_snapshot() + + # No additional usage 
after snapshot + deltas = stub_model_registry.get_usage_deltas(pre_snapshot) + assert deltas == {} + + @patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) @patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) def test_run_health_check_success(mock_completion, mock_generate_text_embeddings, stub_model_registry):
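
Below is a minimal, illustrative sketch of driving the `TelemetryHandler` introduced in this diff directly, outside the dataset builder. The model name, token counts, client version, and session ID are placeholder values, not part of the patch.

```python
from data_designer.engine.models.telemetry import (
    InferenceEvent,
    NemoSourceEnum,
    TaskStatusEnum,
    TelemetryHandler,
)

# Placeholder values for illustration only.
event = InferenceEvent(
    nemo_source=NemoSourceEnum.DATADESIGNER,
    task="preview",
    task_status=TaskStatusEnum.SUCCESS,
    model="openai/gpt-oss-20b",
    input_tokens=1200,
    output_tokens=340,
)

# The handler batches events and flushes them when the context exits;
# enqueue() is a no-op when NEMO_TELEMETRY_ENABLED is set to false.
with TelemetryHandler(source_client_version="0.0.0", session_id="example-session") as handler:
    handler.enqueue(event)
```

This mirrors how `_emit_batch_inference_events` in `column_wise_builder.py` uses the handler for batch and preview runs.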