diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 00000000..23709de7
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,29 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "Run worker",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "trapdata.cli.base",
+            "args": ["worker"]
+        },
+        {
+            "name": "Run api",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "trapdata.cli.base",
+            "args": ["api"]
+        }
+    ]
+}
diff --git a/trapdata/api/api.py b/trapdata/api/api.py
index 632323cf..5d42f8fd 100644
--- a/trapdata/api/api.py
+++ b/trapdata/api/api.py
@@ -5,6 +5,7 @@
 import enum
 import time
+from contextlib import asynccontextmanager
 
 import fastapi
 import pydantic
@@ -36,7 +37,18 @@ from .schemas import PipelineResultsResponse as PipelineResponse_
 from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse
 
-app = fastapi.FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: fastapi.FastAPI):
+    # Cache the service info so it is built only once, at startup
+    app.state.service_info = initialize_service_info()
+    logger.info("Initialized service info")
+    yield
+    # Shutdown: clean up resources here if necessary
+    logger.info("Shutting down API")
+
+
+app = fastapi.FastAPI(lifespan=lifespan)
 app.add_middleware(GZipMiddleware)
@@ -157,13 +169,6 @@ def make_pipeline_config_response(
     )
 
 
-# @TODO This requires loading all models into memory! Can we avoid this?
-PIPELINE_CONFIGS = [
-    make_pipeline_config_response(classifier_class, slug=key)
-    for key, classifier_class in CLASSIFIER_CHOICES.items()
-]
-
-
 class PipelineRequest(PipelineRequest_):
     pipeline: PipelineChoice = pydantic.Field(
         description=PipelineRequest_.model_fields["pipeline"].description,
@@ -313,17 +318,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
 @app.get("/info", tags=["services"])
 async def info() -> ProcessingServiceInfoResponse:
-    info = ProcessingServiceInfoResponse(
-        name="Antenna Inference API",
-        description=(
-            "The primary endpoint for processing images for the Antenna platform. "
-            "This API provides access to multiple detection and classification "
-            "algorithms by multiple labs for processing images of moths."
-        ),
-        pipelines=PIPELINE_CONFIGS,
-        # algorithms=list(algorithm_choices.values()),
-    )
-    return info
+    return app.state.service_info
 
 
 # Check if the server is online
@@ -361,6 +356,26 @@ async def readyz():
     # pass
 
 
+def initialize_service_info() -> ProcessingServiceInfoResponse:
+    # @TODO This requires loading all models into memory! Can we avoid this?
+    pipeline_configs = [
+        make_pipeline_config_response(classifier_class, slug=key)
+        for key, classifier_class in CLASSIFIER_CHOICES.items()
+    ]
+
+    _info = ProcessingServiceInfoResponse(
+        name="Antenna Inference API",
+        description=(
+            "The primary endpoint for processing images for the Antenna platform. "
+            "This API provides access to multiple detection and classification "
+            "algorithms by multiple labs for processing images of moths."
+        ),
+        pipelines=pipeline_configs,
+        # algorithms=list(algorithm_choices.values()),
+    )
+    return _info
+
+
 if __name__ == "__main__":
     import uvicorn
diff --git a/trapdata/api/datasets.py b/trapdata/api/datasets.py
index 57cf9ba1..c0c96d0b 100644
--- a/trapdata/api/datasets.py
+++ b/trapdata/api/datasets.py
@@ -1,12 +1,16 @@
+import os
 import typing
+from io import BytesIO
 
+import requests
 import torch
 import torch.utils.data
 import torchvision
+from PIL import Image
 
 from trapdata.common.logs import logger
 
-from .schemas import DetectionResponse, SourceImage
+from .schemas import DetectionResponse, PipelineProcessingTask, SourceImage
 
 
 class LocalizationImageDataset(torch.utils.data.Dataset):
@@ -87,3 +91,238 @@ def __getitem__(self, idx):
 
         # return (ids_batch, image_batch)
         return (source_image.id, detection_idx), image_data
+
+
+class RESTDataset(torch.utils.data.IterableDataset):
+    """
+    An IterableDataset that fetches tasks from a REST API endpoint and loads images.
+
+    The dataset continuously polls the API for tasks, loads the associated images,
+    and yields them as PyTorch tensors along with metadata.
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        job_id: int,
+        batch_size: int = 1,
+        image_transforms: typing.Optional[torchvision.transforms.Compose] = None,
+        auth_token: typing.Optional[str] = None,
+    ):
+        """
+        Initialize the REST dataset.
+
+        Args:
+            base_url: Base URL for the API (e.g., "http://localhost:8000")
+            job_id: The job ID to fetch tasks for
+            batch_size: Number of tasks to request per batch
+            image_transforms: Optional transforms to apply to loaded images
+            auth_token: API authentication token. If not provided, reads from
+                the ANTENNA_API_TOKEN environment variable
+        """
+        super().__init__()
+        self.base_url = base_url.rstrip("/")
+        self.job_id = job_id
+        self.batch_size = batch_size
+        self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
+        self.auth_token = auth_token or os.environ.get("ANTENNA_API_TOKEN")
+
+    def _fetch_tasks(self) -> list[PipelineProcessingTask]:
+        """
+        Fetch a batch of tasks from the REST API.
+
+        Returns:
+            List of tasks parsed from the API response; an empty list on error.
+        """
+        url = f"{self.base_url}/api/v2/jobs/{self.job_id}/tasks"
+        params = {"batch": self.batch_size}
+
+        headers = {}
+        if self.auth_token:
+            headers["Authorization"] = f"Token {self.auth_token}"
+
+        try:
+            response = requests.get(
+                url,
+                params=params,
+                timeout=30,
+                headers=headers,
+            )
+            response.raise_for_status()
+            data = response.json()
+            tasks = [PipelineProcessingTask(**task) for task in data.get("tasks", [])]
+            return tasks
+        except requests.RequestException as e:
+            logger.error(f"Failed to fetch tasks from {url}: {e}")
+            return []
+
+    def _load_image(self, image_url: str) -> typing.Optional[torch.Tensor]:
+        """
+        Load an image from a URL and convert it to a PyTorch tensor.
+
+        Args:
+            image_url: URL of the image to load
+
+        Returns:
+            Image as a PyTorch tensor, or None if loading failed
+        """
+        try:
+            response = requests.get(image_url, timeout=30)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content))
+
+            # Convert to RGB if necessary
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            # Apply transforms
+            image_tensor = self.image_transforms(image)
+            return image_tensor
+        except Exception as e:
+            logger.error(f"Failed to load image from {image_url}: {e}")
+            return None
+
+    def __iter__(self):
+        """
+        Iterate over tasks from the REST API.
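+
+        Polls the tasks endpoint until the API returns no more tasks for the
+        job. Rows for images that failed to load are still yielded (with
+        image=None and an error message) so the failure can be reported back.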
+
+        Yields:
+            Dictionary containing:
+                - image: PyTorch tensor of the loaded image (None on failure)
+                - reply_subject: Reply subject for the task
+                - image_id: Image ID
+                - image_url: Image URL
+                - error: Error description, present only if loading failed
+        """
+        try:
+            # Get worker info for debugging
+            worker_info = torch.utils.data.get_worker_info()
+            worker_id = worker_info.id if worker_info else 0
+            num_workers = worker_info.num_workers if worker_info else 1
+
+            logger.info(
+                f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
+            )
+
+            while True:
+                tasks = self._fetch_tasks()
+                # _, t = log_time()
+                # _, t = t(f"Worker {worker_id}: Fetched {len(tasks)} tasks from API")
+
+                # If no tasks returned, dataset is finished
+                if not tasks:
+                    logger.info(
+                        f"Worker {worker_id}: No more tasks for job {self.job_id}, terminating"
+                    )
+                    break
+
+                for task in tasks:
+                    errors = []
+                    # Load the image
+                    # _, t = log_time()
+                    image_tensor = (
+                        self._load_image(task.image_url) if task.image_url else None
+                    )
+                    # _, t = t(f"Loaded image from {task.image_url}")
+
+                    if image_tensor is None:
+                        errors.append("failed to load image")
+
+                    if errors:
+                        logger.warning(
+                            f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}"
+                        )
+
+                    # Yield the data row
+                    row = {
+                        "image": image_tensor,
+                        "reply_subject": task.reply_subject,
+                        "image_id": task.image_id,
+                        "image_url": task.image_url,
+                    }
+                    if errors:
+                        row["error"] = "; ".join(errors)
+                    yield row
+
+            logger.info(f"Worker {worker_id}: Iterator finished")
+        except Exception as e:
+            logger.error(f"Worker {worker_id}: Exception in iterator: {e}")
+            raise
+
+
+def rest_collate_fn(batch: list[dict]) -> dict:
+    """
+    Custom collate function that separates failed and successful items.
+
+    Returns a dict with:
+        - image: Stacked tensor of valid images (omitted if all items failed)
+        - reply_subject: List of reply subjects for valid images
+        - image_id: List of image IDs for valid images
+        - image_url: List of image URLs for valid images
+        - failed_items: List of dicts with metadata for failed items
+    """
+    successful = []
+    failed = []
+
+    for item in batch:
+        if item["image"] is None or item.get("error"):
+            # Failed item
+            failed.append(
+                {
+                    "reply_subject": item["reply_subject"],
+                    "image_id": item["image_id"],
+                    "image_url": item.get("image_url"),
+                    "error": item.get("error", "Unknown error"),
+                }
+            )
+        else:
+            # Successful item
+            successful.append(item)
+
+    # Collate successful items
+    if successful:
+        result = {
+            "image": torch.stack([item["image"] for item in successful]),
+            "reply_subject": [item["reply_subject"] for item in successful],
+            "image_id": [item["image_id"] for item in successful],
+            "image_url": [item.get("image_url") for item in successful],
+        }
+    else:
+        # Empty batch - all items failed
+        result = {
+            "reply_subject": [],
+            "image_id": [],
+            "image_url": [],
+        }
+
+    result["failed_items"] = failed
+
+    return result
+
+
+def get_rest_dataloader(
+    job_id: int,
+    base_url: str = "http://localhost:8000",
+    batch_size: int = 4,
+    num_workers: int = 2,
+    auth_token: typing.Optional[str] = None,
+) -> torch.utils.data.DataLoader:
+    """
+    Build a DataLoader over a RESTDataset.
+
+    Args:
+        job_id: Job ID to fetch tasks for
+        base_url: Base URL for the REST API (default: http://localhost:8000)
+        batch_size: Number of tasks/images per batch (default: 4)
+        num_workers: Number of DataLoader workers (default: 2)
+        auth_token: API authentication token, passed through to the dataset
+    """
+    assert base_url is not None, "Base URL must be provided"
+    base_url = base_url.rstrip("/")
+
+    dataset = RESTDataset(
+        base_url=base_url, job_id=job_id, batch_size=batch_size, auth_token=auth_token
+    )
+
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        collate_fn=rest_collate_fn,
+    )
diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py
index 482c4ac3..4e4ddfd3 100644
--- a/trapdata/api/models/classification.py
+++ b/trapdata/api/models/classification.py
@@ -54,6 +54,10 @@ def __init__(
             "detections"
         )
 
+    def reset(self, detections: typing.Iterable[DetectionResponse]):
+        self.detections = list(detections)
+        self.results = []
+
     def get_dataset(self):
         return ClassificationImageDataset(
             source_images=self.source_images,
@@ -117,19 +121,12 @@ def save_results(
         for image_id, detection_idx, predictions in zip(
             image_ids, detection_idxes, batch_output
         ):
-            detection = self.detections[detection_idx]
-            assert detection.source_image_id == image_id
-
-            classification = ClassificationResponse(
-                classification=self.get_best_label(predictions),
-                scores=predictions.scores,
-                logits=predictions.logit,
-                inference_time=seconds_per_item,
-                algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
-                timestamp=datetime.datetime.now(),
-                terminal=self.terminal,
+            self.update_detection_classification(
+                seconds_per_item,
+                image_id,
+                detection_idx,
+                predictions,
             )
-            self.update_classification(detection, classification)
 
         self.results = self.detections
         logger.info(f"Saving {len(self.results)} detections with classifications")
@@ -149,6 +146,24 @@ def update_classification(
             f"Total classifications: {len(detection.classifications)}"
         )
 
+    def update_detection_classification(
+        self, seconds_per_item, image_id, detection_idx, predictions
+    ):
+        detection = self.detections[detection_idx]
+        assert detection.source_image_id == image_id
+
+        classification = ClassificationResponse(
+            classification=self.get_best_label(predictions),
+            scores=predictions.scores,
+            logits=predictions.logit,
+            inference_time=seconds_per_item,
+            algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
+            timestamp=datetime.datetime.now(),
+            terminal=self.terminal,
+        )
+        self.update_classification(detection, classification)
+        return detection
+
     def run(self) -> list[DetectionResponse]:
         logger.info(
             f"Starting {self.__class__.__name__} run with {len(self.results)} "
diff --git a/trapdata/api/models/localization.py b/trapdata/api/models/localization.py
index 600fc9f7..9ec1acd5 100644
--- a/trapdata/api/models/localization.py
+++ b/trapdata/api/models/localization.py
@@ -1,4 +1,3 @@
-import concurrent.futures
 import datetime
 import typing
 
@@ -17,6 +16,10 @@ def __init__(self, source_images: typing.Iterable[SourceImage], *args, **kwargs)
         self.results: list[DetectionResponse] = []
         super().__init__(*args, **kwargs)
 
+    def reset(self, source_images: typing.Iterable[SourceImage]):
+        self.source_images = source_images
+        self.results = []
+
     def get_dataset(self):
         return LocalizationImageDataset(
             self.source_images, self.get_transforms(), batch_size=self.batch_size
@@ -43,15 +46,9 @@ def save_detection(image_id, coords):
             )
             return detection
 
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = []
-            for image_id, image_output in zip(item_ids, batch_output):
-                for coords in image_output:
-                    future = executor.submit(save_detection, image_id, coords)
-                    futures.append(future)
-
-            for future in concurrent.futures.as_completed(futures):
-                detection = future.result()
+        for image_id, image_output in zip(item_ids, batch_output):
+            for coords in image_output:
+                detection = save_detection(image_id, coords)
+                detections.append(detection)
 
         self.results += detections
diff --git a/trapdata/api/schemas.py b/trapdata/api/schemas.py
index a8b682ac..5fee5fbb 100644
--- a/trapdata/api/schemas.py
+++ b/trapdata/api/schemas.py
@@ -282,6 +282,20 @@ class PipelineResultsResponse(pydantic.BaseModel):
     config: PipelineConfigRequest = PipelineConfigRequest()
 
 
+class PipelineProcessingTask(pydantic.BaseModel):
+    """
+    A task representing a single image or detection to be processed in an async pipeline.
+    """
+
+    id: str
+    image_id: str
+    image_url: str
+    reply_subject: str | None = None  # The NATS subject to send the result to
+    # TODO: Do we need these?
+    # detections: list[DetectionRequest] | None = None
+    # config: PipelineRequestConfigParameters | dict | None = None
+
+
 class PipelineStageParam(pydantic.BaseModel):
     """A configurable parameter of a stage of a pipeline."""
diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py
index f53cb651..65dddd72 100644
--- a/trapdata/cli/base.py
+++ b/trapdata/cli/base.py
@@ -1,8 +1,9 @@
 import pathlib
-from typing import Optional
+from typing import List, Optional
 
 import typer
 
+from trapdata.api.api import CLASSIFIER_CHOICES
 from trapdata.cli import db, export, queue, settings, shell, show, test
 from trapdata.db.base import get_session_class
 from trapdata.db.models.events import get_or_create_monitoring_sessions
@@ -96,5 +97,30 @@ def run_api(port: int = 2000):
     uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=port, reload=True)
 
 
+@cli.command("worker")
+def worker(
+    pipelines: List[str] = typer.Option(
+        ["moth_binary"],  # Default to a list with one pipeline
+        help="List of pipelines to use for processing (e.g., moth_binary, panama_moths_2024)",
+    )
+):
+    """
+    Run the worker to process images from the REST API queue.
+    """
+    # Validate that each pipeline is in CLASSIFIER_CHOICES
+    invalid_pipelines = [
+        pipeline for pipeline in pipelines if pipeline not in CLASSIFIER_CHOICES
+    ]
+
+    if invalid_pipelines:
+        raise typer.BadParameter(
+            f"Invalid pipeline(s): {', '.join(invalid_pipelines)}. "
+            f"Must be one of: {', '.join(CLASSIFIER_CHOICES.keys())}"
+        )
+
+    from trapdata.cli.worker import run_worker
+
+    run_worker(pipelines=pipelines)
+
+
 if __name__ == "__main__":
     cli()
diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py
new file mode 100644
index 00000000..49506049
--- /dev/null
+++ b/trapdata/cli/worker.py
@@ -0,0 +1,275 @@
+"""Worker to process images from the REST API queue."""
+
+import datetime
+import os
+import time
+from typing import List, Optional
+
+import numpy as np
+import requests
+import torch
+
+from trapdata.api.api import CLASSIFIER_CHOICES
+from trapdata.api.datasets import get_rest_dataloader
+from trapdata.api.models.localization import APIMothDetector
+from trapdata.api.schemas import (
+    DetectionResponse,
+    PipelineResultsResponse,
+    SourceImageResponse,
+)
+from trapdata.common.logs import logger
+from trapdata.common.utils import log_time
+
+SLEEP_TIME_SECONDS = 5
+
+
+def post_batch_results(
+    base_url: str, job_id: int, results: list[dict], auth_token: Optional[str] = None
+) -> bool:
+    """
+    Post batch results back to the API.
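+
+    Sends a single POST request to {base_url}/api/v2/jobs/{job_id}/result/
+    with the full list of results as the JSON body, so one round trip
+    covers the whole batch.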
+
+    Args:
+        base_url: Base URL for the API
+        job_id: Job ID
+        results: List of dicts, each containing a reply_subject and a result payload
+        auth_token: API authentication token
+
+    Returns:
+        True if successful, False otherwise
+    """
+    url = f"{base_url}/api/v2/jobs/{job_id}/result/"
+
+    headers = {}
+    if auth_token:
+        headers["Authorization"] = f"Token {auth_token}"
+
+    try:
+        response = requests.post(url, json=results, headers=headers, timeout=60)
+        response.raise_for_status()
+        logger.info(f"Successfully posted {len(results)} results to {url}")
+        return True
+    except requests.RequestException as e:
+        logger.error(f"Failed to post results to {url}: {e}")
+        return False
+
+
+def _get_jobs(base_url: str, auth_token: str, pipeline_slug: str) -> list:
+    """Fetch job ids from the API for the given pipeline.
+
+    Calls: GET {base_url}/api/v2/jobs?pipeline__slug=<slug>&ids_only=1&incomplete_only=1
+
+    Returns a list of job ids; an empty list on error or when there are none.
+    """
+    try:
+        url = f"{base_url.rstrip('/')}/api/v2/jobs"
+        params = {"pipeline__slug": pipeline_slug, "ids_only": 1, "incomplete_only": 1}
+
+        headers = {}
+        if auth_token:
+            headers["Authorization"] = f"Token {auth_token}"
+
+        resp = requests.get(url, params=params, headers=headers, timeout=30)
+        resp.raise_for_status()
+        data = resp.json()
+
+        jobs = data.get("results") or []
+        if not isinstance(jobs, list):
+            logger.warning(f"Unexpected results format from {url}: {type(jobs)}")
+            return []
+        return [job["id"] for job in jobs]
+    except requests.RequestException as e:
+        logger.error(f"Failed to fetch jobs from {base_url}: {e}")
+        return []
+
+
+def run_worker(pipelines: List[str]):
+    """Run the worker to process images from the REST API queue."""
+
+    base_url = os.environ.get("ANTENNA_API_BASE_URL", "http://localhost:8000")
+    auth_token = os.environ.get("ANTENNA_API_TOKEN", "")
+    # TODO CGJS: Support a list of pipelines
+    while True:
+        # TODO CGJS: Support pulling and prioritizing single image tasks, which
+        # are used in interactive testing. These should probably come from a
+        # dedicated endpoint and should preempt batch jobs, under the assumption
+        # that they would run on the same GPU.
+        any_jobs = False
+        for pipeline in pipelines:
+            logger.info(f"Checking for jobs for pipeline {pipeline}")
+            jobs = _get_jobs(
+                base_url=base_url, auth_token=auth_token, pipeline_slug=pipeline
+            )
+            for job_id in jobs:
+                logger.info(f"Processing job {job_id} with pipeline {pipeline}")
+                any_work_done = _process_job(
+                    pipeline=pipeline,
+                    job_id=job_id,
+                    base_url=base_url,
+                    auth_token=auth_token,
+                )
+                any_jobs = any_jobs or any_work_done
+
+        if not any_jobs:
+            logger.info(f"No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds")
+            time.sleep(SLEEP_TIME_SECONDS)
+
+
+@torch.no_grad()
+def _process_job(pipeline: str, job_id: int, base_url: str, auth_token: str) -> bool:
+    """Process a single job's tasks from the REST API queue.
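+
+    Streams batches of tasks through the REST dataloader, runs the detector
+    on each batch, classifies each detected crop with the selected pipeline,
+    and posts per-image results (including failures) back to the API.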
+
+    Args:
+        pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
+        job_id: Job ID to process
+        base_url: Base URL for the API
+        auth_token: API authentication token
+
+    Returns:
+        True if any work was done, False otherwise
+    """
+    assert auth_token, "ANTENNA_API_TOKEN environment variable not set"
+    did_work = False
+    loader = get_rest_dataloader(
+        job_id=job_id, base_url=base_url, auth_token=auth_token
+    )
+    classifier = None
+    detector = None
+
+    torch.cuda.empty_cache()
+    items = 0
+
+    total_detection_time = 0.0
+    total_classification_time = 0.0
+    total_save_time = 0.0
+    total_dl_time = 0.0
+    all_detections = []
+    _, t = log_time()
+
+    for i, batch in enumerate(loader):
+        dt, t = t("Finished loading batch")
+        total_dl_time += dt
+        if not batch:
+            logger.warning(f"Batch {i + 1} is empty, skipping")
+            continue
+
+        # Defer instantiation of the detector and classifier until we have data
+        if not classifier:
+            classifier_class = CLASSIFIER_CHOICES[pipeline]
+            classifier = classifier_class(source_images=[], detections=[])
+            detector = APIMothDetector([])
+        assert detector is not None, "Detector not initialized"
+        assert classifier is not None, "Classifier not initialized"
+        detector.reset([])
+        did_work = True
+
+        # Extract data from the dictionary batch
+        batch_input = batch.get("image", [])
+        item_ids = batch.get("image_id", [])
+        reply_subjects = batch.get("reply_subject", [None] * len(batch_input))
+        image_urls = batch.get("image_url", [None] * len(batch_input))
+
+        # Track start time for this batch
+        batch_start_time = datetime.datetime.now()
+
+        logger.info(f"Processing batch {i + 1}")
+        # Output is a dict of "boxes", "labels", "scores"
+        batch_output = []
+        if len(batch_input) > 0:
+            batch_output = detector.predict_batch(batch_input)
+
+        items += len(batch_output)
+        logger.info(f"Total items processed so far: {items}")
+        batch_output = list(detector.post_process_batch(batch_output))
+
+        # Convert item_ids to a list if needed
+        if isinstance(item_ids, (np.ndarray, torch.Tensor)):
+            item_ids = item_ids.tolist()
+
+        # TODO CGJS: Add seconds-per-item calculation for both detector and classifier
+        detector.save_results(
+            item_ids=item_ids,
+            batch_output=batch_output,
+            seconds_per_item=0,
+        )
+        dt, t = t("Finished detection")
+        total_detection_time += dt
+
+        # Group detections by image_id
+        image_detections: dict[str, list[DetectionResponse]] = {
+            img_id: [] for img_id in item_ids
+        }
+        image_tensors = dict(zip(item_ids, batch_input))
+
+        classifier.reset(detector.results)
+
+        for idx, dresp in enumerate(detector.results):
+            image_tensor = image_tensors[dresp.source_image_id]
+            bbox = dresp.bbox
+            # Crop the image tensor using the bbox
+            crop = image_tensor[
+                :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2)
+            ]
+            crop = crop.unsqueeze(0)  # add batch dimension
+            classifier_out = classifier.predict_batch(crop)
+            # Materialize in case post-processing returns a generator
+            classifier_out = list(classifier.post_process_batch(classifier_out))
+            detection = classifier.update_detection_classification(
+                seconds_per_item=0,
+                image_id=dresp.source_image_id,
+                detection_idx=idx,
+                predictions=classifier_out[0],
+            )
+            image_detections[dresp.source_image_id].append(detection)
+            all_detections.append(detection)
+
+        ct, t = t("Finished classification")
+        total_classification_time += ct
+
+        # Calculate batch processing time
+        batch_end_time = datetime.datetime.now()
+        batch_elapsed = (batch_end_time - batch_start_time).total_seconds()
+
+        # Post results back to the API with a PipelineResponse for each image
+        batch_results = []
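+        # One result entry per image: each carries its reply_subject so the
+        # API can route the result back to the task that requested it.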
+        for reply_subject, image_id, image_url in zip(
+            reply_subjects, item_ids, image_urls
+        ):
+            # Create SourceImageResponse for this image
+            source_image = SourceImageResponse(id=image_id, url=image_url)
+
+            # Create PipelineResultsResponse
+            pipeline_response = PipelineResultsResponse(
+                pipeline=pipeline,
+                source_images=[source_image],
+                detections=image_detections[image_id],
+                total_time=batch_elapsed / len(item_ids),  # Approximate time per image
+            )
+
+            batch_results.append(
+                {
+                    "reply_subject": reply_subject,
+                    "result": pipeline_response.model_dump(mode="json"),
+                }
+            )
+        failed_items = batch.get("failed_items")
+        if failed_items:
+            for failed_item in failed_items:
+                batch_results.append(
+                    {
+                        "reply_subject": failed_item.get("reply_subject"),
+                        # TODO CGJS: Should we extend PipelineResultsResponse to include errors?
+                        "result": {
+                            "error": failed_item.get("error", "Unknown error"),
+                            "image_id": failed_item.get("image_id"),
+                        },
+                    }
+                )
+
+        post_batch_results(base_url, job_id, batch_results, auth_token)
+        st, t = t("Finished posting results")
+        total_save_time += st
+
+    logger.info(
+        f"Done, detections: {len(all_detections)}. Detection time: {total_detection_time}, "
+        f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}"
+    )
+    return did_work
diff --git a/trapdata/common/utils.py b/trapdata/common/utils.py
index 15d11b2a..80c52966 100644
--- a/trapdata/common/utils.py
+++ b/trapdata/common/utils.py
@@ -1,9 +1,11 @@
 import csv
 import datetime
+import functools
 import pathlib
 import random
 import string
-from typing import Any, Union
+import time
+from typing import Any, Callable, Optional, Tuple, Union
 
 
 def get_sequential_sample(direction, images, last_sample=None):
@@ -119,3 +121,29 @@ def random_color():
     color = [random.random() for _ in range(3)]
     color.append(0.8)  # alpha
     return color
+
+
+def log_time(start: float = 0, msg: Optional[str] = None) -> Tuple[float, Callable]:
+    """
+    Small helper to measure time between calls.
+
+    Returns: elapsed time since the last call, and a partial function to measure
+    from the current call.
+
+    Usage:
+        _, tlog = log_time()
+        # do something
+        _, tlog = tlog("Did something")  # logs the time taken by 'something'
+        # do something else
+        t, tlog = tlog("Did something else")  # logs the time taken by 'something else', returned as 't'
+    """
+    from trapdata.common.logs import logger
+
+    end = time.perf_counter()
+    if start == 0:
+        dur = 0.0
+    else:
+        dur = end - start
+    if msg and start > 0:
+        logger.info(f"{msg}: {dur:.3f}s")
+    new_start = time.perf_counter()
+    return dur, functools.partial(log_time, new_start)
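
Usage sketch (illustrative only, not part of the patch): how the pieces above are expected to fit together, assuming a reachable Antenna API, a token in ANTENNA_API_TOKEN, and a hypothetical job id.

# Start the API and a worker, mirroring the .vscode/launch.json configurations:
#   python -m trapdata.cli.base api
#   python -m trapdata.cli.base worker --pipelines moth_binary
#
# Or drive the dataloader directly to inspect a job's batches:
import os

from trapdata.api.datasets import get_rest_dataloader
from trapdata.common.utils import log_time

loader = get_rest_dataloader(
    job_id=11,  # hypothetical job id, for illustration
    base_url=os.environ.get("ANTENNA_API_BASE_URL", "http://localhost:8000"),
    batch_size=4,
    num_workers=2,
)
_, t = log_time()
for batch in loader:
    # "image" is only present when at least one item in the batch loaded
    images = batch.get("image")
    _, t = t(f"Loaded batch with {len(batch['image_id'])} image ids")
    for failed in batch["failed_items"]:
        print(f"Failed to load {failed['image_id']}: {failed['error']}")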