Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions examples/01_LocalBenchmark/run_tinyllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inference_endpoint.config.rulesets.mlcommons.models as mlcommons_models
from inference_endpoint.config.rulesets.mlcommons.rules import CURRENT
from inference_endpoint.config.user_config import UserConfig
from inference_endpoint.core.types import QueryResult, StreamChunk
from inference_endpoint.core.types import QueryResult, StreamChunk, TextModelOutput
from inference_endpoint.dataset_manager.dataset import Dataset
from inference_endpoint.load_generator import (
BenchmarkSession,
Expand Down Expand Up @@ -167,10 +167,15 @@ def issue(self, sample):
)
SampleEventHandler.stream_chunk_complete(stream_chunk)
first = False
query_result = QueryResult(id=sample.uuid, response_output=chunks)
query_result = QueryResult(
id=sample.uuid,
response_output=TextModelOutput(output=chunks, reasoning=None),
)
else:
response = self.compute_func(sample.data)
query_result = QueryResult(id=sample.uuid, response_output=response)
query_result = QueryResult(
id=sample.uuid, response_output=TextModelOutput(output=response)
)
SampleEventHandler.query_result_complete(query_result)


Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ dependencies = [
"colorama==0.4.6",
# Fix pytz-2024 import warning
"pytz==2025.2",
# SQL event logger (swappable backends, default sqlite)
"sqlalchemy==2.0.48",
]

[project.optional-dependencies]
Expand Down
171 changes: 171 additions & 0 deletions src/inference_endpoint/async_utils/services/event_logger/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""EventPublisherService subscriber for logging event records.

Currently supported:
- JSONL file
- SQL database (SQLAlchemy; default sqlite, swappable backends)
"""

import argparse
import asyncio
import os
from pathlib import Path

from inference_endpoint.async_utils.loop_manager import LoopManager
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber
from inference_endpoint.core.record import (
ErrorEventType,
EventRecord,
SessionEventType,
)

from .file_writer import JSONLWriter
from .sql_writer import SQLWriter
from .writer import RecordWriter

# CLI writer names to writer classes (for --writers flag)
# NOTE: insertion order is what `choices=list(_WRITER_REGISTRY)` displays in --help.
_WRITER_REGISTRY: dict[str, type[RecordWriter]] = {
    "jsonl": JSONLWriter,
    "sql": SQLWriter,
}


def _is_error_event(record: EventRecord) -> bool:
    """Return True for error events; these must survive past session ENDED."""
    event_kind = record.event_type
    return isinstance(event_kind, ErrorEventType)


class EventLoggerService(ZmqEventRecordSubscriber):
    """Event logger service for logging event records.

    When SessionEventType.ENDED is received (topic 'session.ended'), the service stops
    accepting further events (except Error events), closes writers, and stops the event loop.
    Writers are only closed after the current batch is fully processed, so error
    events that appear in the same batch after ENDED are still written.
    """

    def __init__(
        self,
        log_dir: Path,
        *args,
        writer_classes: tuple[type[RecordWriter], ...] = (JSONLWriter,),
        flush_interval: int | None = 100,
        shutdown_event: asyncio.Event | None = None,
        **kwargs,
    ):
        """Validate the log directory and construct one writer per class.

        Args:
            log_dir: Output directory; created (with parents) if missing.
            *args: Forwarded to ZmqEventRecordSubscriber.
            writer_classes: Writer types instantiated against ``log_dir / "events"``.
            flush_interval: Forwarded to each writer (flush after this many records).
            shutdown_event: If provided, set on stop instead of stopping the loop.
            **kwargs: Forwarded to ZmqEventRecordSubscriber.

        Raises:
            NotADirectoryError: If ``log_dir`` exists but is not a directory.
            PermissionError: If ``log_dir`` is not writable.
        """
        super().__init__(*args, **kwargs)
        self._shutdown_received = False
        self._shutdown_event = shutdown_event

        if not log_dir.exists():
            log_dir.mkdir(parents=True, exist_ok=True)

        if not log_dir.is_dir():
            raise NotADirectoryError(f"Log directory {log_dir} is not a directory")

        if not os.access(log_dir, os.W_OK):
            raise PermissionError(f"Log directory {log_dir} is not writable")

        # All writers share the same base path; each applies its own suffix
        # (e.g. events.jsonl, events.db).
        self.writers: list[RecordWriter] = []
        for writer_class in writer_classes:
            self.writers.append(
                writer_class(log_dir / "events", flush_interval=flush_interval)
            )

    def _write_record_to_writers(self, record: EventRecord) -> None:
        """Write a single record to all writers (uses write for flush-on-interval)."""
        for writer in self.writers:
            writer.write(record)

    def _close_writers_and_stop(self) -> None:
        """Flush and close all writers, clear the list, then request loop stop."""
        for writer in self.writers:
            writer.flush()
            writer.close()
        self.writers.clear()
        # _request_stop must run on the subscriber's loop thread, so marshal it
        # via call_soon_threadsafe rather than calling it directly.
        if self.loop is not None:
            self.loop.call_soon_threadsafe(self._request_stop)

    async def process(self, records: list[EventRecord]) -> None:
        """Write a batch of records; after an ENDED record, drop non-error events.

        Writers are torn down only after the whole batch is handled, so error
        records following ENDED within the same batch are still logged.
        """
        saw_shutdown = False
        for record in records:
            # After ENDED, only error events are still accepted.
            if self._shutdown_received and not _is_error_event(record):
                continue
            if record.event_type == SessionEventType.ENDED:
                self._shutdown_received = True
                saw_shutdown = True
            self._write_record_to_writers(record)
        if saw_shutdown:
            self._close_writers_and_stop()

    def _request_stop(self) -> None:
        """Close this subscriber and signal shutdown (or stop the loop if no shutdown_event)."""
        self.close()
        if self._shutdown_event is not None:
            self._shutdown_event.set()
        elif self.loop is not None and self.loop.is_running():
            self.loop.stop()

    def close(self) -> None:
        """Flush and close any remaining writers, then close the ZMQ subscriber."""
        for writer in self.writers:
            writer.flush()
            writer.close()
        self.writers.clear()
        super().close()


def _parse_cli_args() -> argparse.Namespace:
    """Parse the event-logger command line (--log-dir, --socket-address, --writers)."""
    parser = argparse.ArgumentParser(description="Event logger service")
    parser.add_argument("--log-dir", type=Path, required=True, help="Log directory")
    parser.add_argument(
        "--socket-address",
        type=str,
        required=True,
        help="ZMQ socket address to connect to",
    )
    parser.add_argument(
        "--writers",
        nargs="+",
        choices=list(_WRITER_REGISTRY),
        default=["jsonl"],
        metavar="WRITER",
        help="Writers to use: jsonl, sql (default: jsonl). Can specify multiple, e.g. --writers jsonl sql",
    )
    return parser.parse_args()


async def main() -> None:
    """Run the event logger service until a session ENDED event triggers shutdown."""
    cli = _parse_cli_args()
    selected_writers = tuple(_WRITER_REGISTRY[name] for name in cli.writers)
    done = asyncio.Event()
    event_loop = LoopManager().default_loop
    with ManagedZMQContext.scoped(socket_dir=cli.log_dir.parent) as zmq_ctx:
        service = EventLoggerService(
            cli.log_dir,
            cli.socket_address,
            zmq_ctx,
            event_loop,
            topics=None,  # Subscribe to all topics for logging
            writer_classes=selected_writers,
            shutdown_event=done,
        )

        # start() must run on the subscriber's loop thread.
        event_loop.call_soon_threadsafe(service.start)
        await done.wait()


if __name__ == "__main__":
    LoopManager().default_loop.run_until_complete(main())
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JSONL file writer for event records.

If additional file-based writers are needed in the future, the shared file I/O
logic (open, flush, close, flush_interval) should be refactored out of
JSONLWriter into a ``StreamedFileWriter`` base class sitting between
``RecordWriter`` and the concrete writers.
"""

from pathlib import Path

import msgspec
from inference_endpoint.core.record import EventRecord, EventType

from .writer import RecordWriter


class JSONLWriter(RecordWriter):
    """Writes event records to a JSONL file (one JSON object per line)."""

    extension = ".jsonl"

    def __init__(
        self,
        file_path: Path,
        mode: str = "w",
        flush_interval: int | None = None,
        **kwargs: object,
    ):
        """Open the output file.

        Args:
            file_path: Base path; the suffix is forced to ``.jsonl``.
            mode: File open mode (default "w", truncating any existing file).
            flush_interval: If set, flush after this many records (handled by
                the RecordWriter base class).
            **kwargs: Ignored; accepted so all writers share a construction signature.
        """
        super().__init__(flush_interval=flush_interval)
        self.file_path = Path(file_path).with_suffix(self.extension)
        # The msgspec encoder emits UTF-8 bytes; open the file with an explicit
        # matching encoding so output does not depend on the platform default
        # (e.g. cp1252 on Windows, which would raise on non-latin text).
        self.file_obj = self.file_path.open(mode=mode, encoding="utf-8")  # type: ignore[assignment]
        self.encoder = msgspec.json.Encoder(enc_hook=EventType.encode_hook)

    def _write_record(self, record: EventRecord) -> None:
        """Append one record as a JSON line; no-op after close()."""
        if self.file_obj is not None:
            self.file_obj.write(self.encoder.encode(record).decode("utf-8") + "\n")

    def flush(self) -> None:
        """Flush buffered file output, then delegate to the base class."""
        if self.file_obj is not None:
            self.file_obj.flush()
        super().flush()

    def close(self) -> None:
        """Flush and close the file; idempotent and best-effort on I/O errors."""
        if self.file_obj is not None:
            try:
                self.flush()
                self.file_obj.close()
            # FileNotFoundError is already a subclass of OSError, so a single
            # OSError catch covers the original (OSError, FileNotFoundError) pair.
            except OSError:
                pass
            finally:
                self.file_obj = None  # type: ignore[assignment]
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SQL writer for event records using SQLAlchemy (swappable SQL backends, default sqlite)."""

from pathlib import Path

import msgspec
from inference_endpoint.core.record import EventRecord
from sqlalchemy import BigInteger, Integer, LargeBinary, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

from .writer import RecordWriter


class Base(DeclarativeBase):
    """SQLAlchemy declarative base shared by the event logger's ORM models."""


class EventRowModel(Base):
    """SQLAlchemy model for event rows.

    Schema aligned with metrics/recorder.EventRow but uses EventType topic strings
    (e.g. 'session.ended', 'sample.complete') for event_type instead of legacy Event enum values.
    """

    __tablename__ = "events"

    # Surrogate primary key; autoincrement preserves insertion order.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    sample_uuid: Mapped[str] = mapped_column(String, nullable=False, default="")
    """UUID string identifier for the sample."""

    event_type: Mapped[str] = mapped_column(String, nullable=False)
    """Event type as topic string (e.g. 'session.ended', 'sample.complete')."""

    timestamp_ns: Mapped[int] = mapped_column(BigInteger, nullable=False)
    """Monotonic timestamp in nanoseconds."""

    data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False, default=b"")
    """JSON-encoded event data."""


def _record_to_row(record: EventRecord) -> EventRowModel:
    """Build an EventRowModel from *record*, ready to be added to a session."""
    # The EventTypeMeta metaclass sets a .topic string on each enum member.
    event_topic = record.event_type.topic  # type: ignore[attr-defined]
    encoded_payload = msgspec.json.encode(record.data)
    return EventRowModel(
        sample_uuid=record.sample_uuid,
        event_type=event_topic,
        timestamp_ns=record.timestamp_ns,
        data=encoded_payload,
    )


class SQLWriter(RecordWriter):
    """Writes event records to a SQL database via SQLAlchemy.

    Uses SQLAlchemy so the backend can be swapped (e.g. sqlite, postgresql).
    Default URL is sqlite at the given path with .db suffix.
    """

    def __init__(
        self,
        path: Path,
        url: str | None = None,
        flush_interval: int | None = None,
        **kwargs: object,
    ):
        """Initialize the SQL writer.

        Args:
            path: Base path for the database. For sqlite default, the file will be path.with_suffix(".db").
            url: Optional SQLAlchemy database URL. If None, uses sqlite at path.with_suffix(".db").
            flush_interval: If set, flush (commit) after every this many records.
            **kwargs: Ignored; accepted so all writers share a construction signature.
        """
        super().__init__(flush_interval=flush_interval)
        if url is None:
            db_path = Path(path).with_suffix(".db")
            url = f"sqlite:///{db_path}"
        self._engine = create_engine(url)
        # Create the 'events' table up front so writes never race table creation.
        Base.metadata.create_all(self._engine)
        # autoflush=False: rows are batched and only pushed on explicit commit;
        # expire_on_commit=False: committed rows stay usable without re-fetching.
        self._session_factory = sessionmaker(
            bind=self._engine, autoflush=False, expire_on_commit=False
        )
        self._session = self._session_factory()

    def _write_record(self, record: EventRecord) -> None:
        """Stage one record in the session; committed by flush(). No-op after close()."""
        if self._session is None:
            return
        row = _record_to_row(record)
        self._session.add(row)

    def flush(self) -> None:
        """Commit staged rows, then delegate to the base class."""
        if self._session is not None:
            self._session.commit()
        super().flush()

    def close(self) -> None:
        """Commit pending rows, close the session, and dispose the engine.

        Idempotent. The engine is disposed even when the final commit or the
        session close raises (the original version skipped dispose() in that
        case, leaking pooled connections).
        """
        try:
            if self._session is not None:
                try:
                    self.flush()
                    self._session.close()
                finally:
                    self._session = None
        finally:
            if self._engine is not None:
                self._engine.dispose()
                self._engine = None
Loading
Loading