diff --git a/README.md b/README.md
index 1796d38..f1919c9 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ outputs = pipe(
messages,
max_new_tokens=256,
)
-print(outputs[0]["generated_text"][-1])
+print(outputs[0]["generated_text"][-1]["content"])
```
[Learn more about how to use gpt-oss with Transformers.](https://cookbook.openai.com/articles/gpt-oss/run-transformers)
@@ -253,7 +253,7 @@ hf download openai/gpt-oss-20b --include "metal/*" --local-dir gpt-oss-20b/metal
To test it you can run:
```shell
-python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/model.bin -p "why did the chicken cross the road?"
+python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/metal/model.bin -p "why did the chicken cross the road?"
```
## Harmony format & tools
diff --git a/_build/gpt_oss_build_backend/backend.py b/_build/gpt_oss_build_backend/backend.py
index 5cd76bd..d0077ef 100644
--- a/_build/gpt_oss_build_backend/backend.py
+++ b/_build/gpt_oss_build_backend/backend.py
@@ -29,112 +29,550 @@
`build-backend = "gpt_oss_build_backend.backend"` in pyproject.toml.
"""
import os
+import sys
from importlib import import_module
-from typing import Any, Mapping, Sequence
-
-
-TRUE_VALUES = {"1", "true", "TRUE", "on", "ON", "yes", "YES"}
+from typing import Any, Callable, List, Mapping, Optional, Protocol, Sequence
+
+
+# Configuration constants
+ENV_VAR_METAL_BUILD = "GPTOSS_BUILD_METAL"
+TRUE_VALUES = {"1", "true", "on", "yes", "y", "t"}
+
+# Build requirements
+METAL_BUILD_REQUIREMENTS = [
+ "scikit-build-core>=0.10",
+ "pybind11>=2.12",
+ "cmake>=3.26",
+ "ninja",
+]
+
+SETUPTOOLS_BUILD_REQUIREMENTS: List[str] = []
+
+
+class BuildBackendProtocol(Protocol):
+ """Protocol defining the interface for build backends."""
+
+ def build_wheel(
+ self,
+ wheel_directory: str,
+ config_settings: Optional[Mapping[str, Any]] = None,
+ metadata_directory: Optional[str] = None,
+ ) -> str:
+ """Build a wheel and return its filename."""
+ ...
+
+ def build_sdist(
+ self,
+ sdist_directory: str,
+ config_settings: Optional[Mapping[str, Any]] = None,
+ ) -> str:
+ """Build a source distribution and return its filename."""
+ ...
+
+ def get_requires_for_build_wheel(
+ self,
+ config_settings: Optional[Mapping[str, Any]] = None,
+ ) -> List[str]:
+ """Get requirements for building a wheel."""
+ ...
+
+ def get_requires_for_build_sdist(
+ self,
+ config_settings: Optional[Mapping[str, Any]] = None,
+ ) -> List[str]:
+ """Get requirements for building a source distribution."""
+ ...
+
+
+class BuildError(Exception):
+ """Custom exception for build-related errors."""
+ pass
+
+
+class ConfigurationError(Exception):
+ """Exception raised for configuration-related errors."""
+ pass
+
+
+def _log_info(message: str) -> None:
+ """
+ Simple logging function that prints to stderr with prefix.
+
+ Args:
+ message: The message to log.
+ """
+ print(f"[gpt-oss-build] {message}", file=sys.stderr)
+
+
+def _log_error(message: str) -> None:
+ """
+ Log error messages to stderr with prefix.
+
+ Args:
+ message: The error message to log.
+ """
+ print(f"[gpt-oss-build] ERROR: {message}", file=sys.stderr)
+
+
+def _validate_directory(directory: str, operation: str) -> None:
+ """
+ Validate that a directory exists and is writable.
+
+ Args:
+ directory: Path to the directory to validate.
+ operation: Description of the operation for error messages.
+
+ Raises:
+ ConfigurationError: If directory validation fails.
+ """
+ if not directory:
+ raise ConfigurationError(f"Directory for {operation} cannot be empty")
+
+ directory_path = os.path.abspath(directory)
+
+ # Create directory if it doesn't exist
+ try:
+ os.makedirs(directory_path, exist_ok=True)
+ except OSError as e:
+ raise ConfigurationError(f"Cannot create directory {directory_path} for {operation}: {e}")
+
+ # Check if directory is writable
+ if not os.access(directory_path, os.W_OK):
+ raise ConfigurationError(f"Directory {directory_path} is not writable for {operation}")
+
+
+def _parse_environment_variable(var_name: str, default: str = "") -> str:
+ """
+ Parse and normalize an environment variable value.
+
+ Args:
+ var_name: Name of the environment variable.
+ default: Default value if the variable is not set.
+
+ Returns:
+ Normalized environment variable value.
+ """
+ value = os.environ.get(var_name, default).strip().lower()
+ _log_info(f"Environment variable {var_name}='{os.environ.get(var_name, default)}' -> normalized: '{value}'")
+ return value
def _use_metal_backend() -> bool:
- return str(os.environ.get("GPTOSS_BUILD_METAL", "")).strip() in TRUE_VALUES
-
-
-def _setuptools_backend():
- from setuptools import build_meta as _bm # type: ignore
-
- return _bm
-
-
-def _scikit_build_backend():
- return import_module("scikit_build_core.build")
-
-
-def _backend():
- return _scikit_build_backend() if _use_metal_backend() else _setuptools_backend()
+ """
+ Check if Metal backend should be used based on GPTOSS_BUILD_METAL environment variable.
+
+ This function performs case-insensitive, whitespace-tolerant parsing of the
+ environment variable and logs the decision for transparency.
+
+ Returns:
+ bool: True if Metal backend should be used, False otherwise.
+
+ Raises:
+ ConfigurationError: If there's an issue with configuration validation.
+ """
+ try:
+ env_value = _parse_environment_variable(ENV_VAR_METAL_BUILD)
+ use_metal = env_value in TRUE_VALUES
+
+ if env_value:
+            backend_type = "Metal (scikit-build-core)" if use_metal else "setuptools (value not recognized as true)"
+ _log_info(f"Backend selection: {backend_type}")
+ else:
+ _log_info("Backend selection: setuptools (default - no environment variable set)")
+
+ return use_metal
+ except Exception as e:
+ _log_error(f"Failed to determine backend from environment: {e}")
+ _log_info("Falling back to setuptools backend")
+ return False
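+
+
+# Typical front-end invocations (illustrative only; the variable name comes from
+# ENV_VAR_METAL_BUILD above):
+#
+#   pip install .                        -> setuptools backend, pure-Python wheel
+#   GPTOSS_BUILD_METAL=1 pip install .   -> scikit-build-core backend with the Metal extension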
+
+
+def _setuptools_backend() -> BuildBackendProtocol:
+ """
+ Get the setuptools build backend.
+
+ Returns:
+ BuildBackendProtocol: The setuptools build backend instance.
+
+ Raises:
+ BuildError: If setuptools backend cannot be imported.
+ """
+ try:
+ from setuptools import build_meta as _bm # type: ignore
+ _log_info("Successfully imported setuptools build backend")
+ return _bm
+ except ImportError as e:
+ _log_error(f"Failed to import setuptools build backend: {e}")
+ raise BuildError("setuptools is required but not available") from e
+
+
+def _scikit_build_backend() -> BuildBackendProtocol:
+ """
+ Get the scikit-build-core build backend.
+
+ Returns:
+ BuildBackendProtocol: The scikit-build-core build backend instance.
+
+ Raises:
+ BuildError: If scikit-build-core backend cannot be imported.
+ """
+ try:
+ backend = import_module("scikit_build_core.build")
+ _log_info("Successfully imported scikit-build-core backend")
+ return backend
+ except ImportError as e:
+ _log_error(f"Failed to import scikit-build-core: {e}")
+ _log_error("Install Metal build dependencies with: pip install 'scikit-build-core>=0.10' cmake ninja")
+ raise BuildError("scikit-build-core is required for Metal backend but not available") from e
+
+
+def _backend() -> BuildBackendProtocol:
+ """
+ Get the appropriate build backend based on configuration.
+
+ This function selects between setuptools and scikit-build-core backends
+ based on the environment configuration, with comprehensive error handling.
+
+ Returns:
+ BuildBackendProtocol: The selected build backend instance.
+
+ Raises:
+ BuildError: If the selected backend cannot be loaded.
+ ConfigurationError: If there's a configuration issue.
+ """
+ try:
+ if _use_metal_backend():
+ _log_info("Configuring scikit-build-core backend for Metal extension")
+ return _scikit_build_backend()
+ else:
+ _log_info("Configuring setuptools backend for pure Python wheel")
+ return _setuptools_backend()
+ except (BuildError, ConfigurationError):
+ # Re-raise known exceptions
+ raise
+ except Exception as e:
+ _log_error(f"Unexpected error while selecting build backend: {e}")
+ raise BuildError(f"Failed to configure build backend: {e}") from e
+
+
+def _safe_getattr(obj: Any, name: str, default: Optional[Callable] = None) -> Optional[Callable]:
+ """
+ Safely get an attribute from an object with logging.
+
+ Args:
+ obj: The object to get the attribute from.
+ name: The name of the attribute.
+ default: Default value if attribute doesn't exist.
+
+ Returns:
+ The attribute value or default.
+ """
+ attr = getattr(obj, name, default)
+ if attr is None:
+ _log_info(f"Backend doesn't implement {name}")
+ else:
+ _log_info(f"Backend supports {name}")
+ return attr
# Required PEP 517 hooks
def build_wheel(
wheel_directory: str,
- config_settings: Mapping[str, Any] | None = None,
- metadata_directory: str | None = None,
+ config_settings: Optional[Mapping[str, Any]] = None,
+ metadata_directory: Optional[str] = None,
) -> str:
- return _backend().build_wheel(wheel_directory, config_settings, metadata_directory)
+ """
+ Build a wheel in the specified directory.
+
+ Args:
+ wheel_directory: Directory where the wheel should be built.
+ config_settings: Optional configuration settings.
+ metadata_directory: Optional metadata directory.
+
+ Returns:
+ str: The filename of the built wheel.
+
+ Raises:
+ BuildError: If the wheel build fails.
+ ConfigurationError: If directory validation fails.
+ """
+ try:
+ _validate_directory(wheel_directory, "wheel building")
+ _log_info(f"Building wheel in directory: {wheel_directory}")
+
+ backend = _backend()
+ result = backend.build_wheel(wheel_directory, config_settings, metadata_directory)
+
+ _log_info(f"Successfully built wheel: {result}")
+ return result
+ except (BuildError, ConfigurationError):
+ raise
+ except Exception as e:
+ _log_error(f"Wheel build failed: {e}")
+ raise BuildError(f"Failed to build wheel: {e}") from e
def build_sdist(
- sdist_directory: str, config_settings: Mapping[str, Any] | None = None
+ sdist_directory: str,
+ config_settings: Optional[Mapping[str, Any]] = None
) -> str:
- return _backend().build_sdist(sdist_directory, config_settings)
+ """
+ Build a source distribution in the specified directory.
+
+ Args:
+ sdist_directory: Directory where the sdist should be built.
+ config_settings: Optional configuration settings.
+
+ Returns:
+ str: The filename of the built source distribution.
+
+ Raises:
+ BuildError: If the sdist build fails.
+ ConfigurationError: If directory validation fails.
+ """
+ try:
+ _validate_directory(sdist_directory, "source distribution building")
+ _log_info(f"Building sdist in directory: {sdist_directory}")
+
+ backend = _backend()
+ result = backend.build_sdist(sdist_directory, config_settings)
+
+ _log_info(f"Successfully built sdist: {result}")
+ return result
+ except (BuildError, ConfigurationError):
+ raise
+ except Exception as e:
+ _log_error(f"Sdist build failed: {e}")
+ raise BuildError(f"Failed to build sdist: {e}") from e
def prepare_metadata_for_build_wheel(
- metadata_directory: str, config_settings: Mapping[str, Any] | None = None
+ metadata_directory: str,
+ config_settings: Optional[Mapping[str, Any]] = None
) -> str:
- # Fallback if backend doesn't implement it
- be = _backend()
- fn = getattr(be, "prepare_metadata_for_build_wheel", None)
- if fn is None:
- # setuptools exposes it; scikit-build-core may not. Defer to building a wheel for metadata.
- return _setuptools_backend().prepare_metadata_for_build_wheel(
- metadata_directory, config_settings
- )
- return fn(metadata_directory, config_settings)
+ """
+ Prepare metadata for wheel building.
+
+ Args:
+ metadata_directory: Directory where metadata should be prepared.
+ config_settings: Optional configuration settings.
+
+ Returns:
+ str: The name of the metadata directory.
+
+ Raises:
+ BuildError: If metadata preparation fails.
+ ConfigurationError: If directory validation fails.
+ """
+ try:
+ _validate_directory(metadata_directory, "metadata preparation")
+ _log_info(f"Preparing metadata in directory: {metadata_directory}")
+
+ backend = _backend()
+ fn = _safe_getattr(backend, "prepare_metadata_for_build_wheel")
+
+ if fn is None:
+ # Fallback to setuptools if current backend doesn't support it
+ _log_info("Using setuptools fallback for metadata preparation")
+ setuptools_backend = _setuptools_backend()
+ result = setuptools_backend.prepare_metadata_for_build_wheel(
+ metadata_directory, config_settings
+ )
+ else:
+ result = fn(metadata_directory, config_settings)
+
+ _log_info(f"Successfully prepared metadata: {result}")
+ return result
+ except (BuildError, ConfigurationError):
+ raise
+ except Exception as e:
+ _log_error(f"Metadata preparation failed: {e}")
+ raise BuildError(f"Failed to prepare metadata: {e}") from e
# Optional hooks
def build_editable(
- editable_directory: str, config_settings: Mapping[str, Any] | None = None, metadata_directory: str | None = None
+    wheel_directory: str,
+    config_settings: Optional[Mapping[str, Any]] = None,
+    metadata_directory: Optional[str] = None,
) -> str:
- be = _backend()
- fn = getattr(be, "build_editable", None)
- if fn is None:
- # setuptools implements build_editable; if not available, raise the standard error
- raise RuntimeError("Editable installs not supported by the selected backend")
- return fn(editable_directory, config_settings)
+ """
+ Build an editable wheel in the specified directory.
+
+ Args:
+ wheel_directory: Directory where the editable wheel should be built.
+ config_settings: Optional configuration settings.
+ metadata_directory: Optional metadata directory.
+
+ Returns:
+ str: The filename of the built editable wheel.
+
+ Raises:
+ BuildError: If the editable build fails or is not supported.
+ ConfigurationError: If directory validation fails.
+ """
+ try:
+ _validate_directory(wheel_directory, "editable wheel building")
+ _log_info(f"Building editable install in directory: {wheel_directory}")
+
+ backend = _backend()
+ fn = _safe_getattr(backend, "build_editable")
+
+ if fn is None:
+ raise BuildError("Editable installs not supported by the selected backend")
+
+ result = fn(wheel_directory, config_settings, metadata_directory)
+ _log_info(f"Successfully built editable wheel: {result}")
+ return result
+ except (BuildError, ConfigurationError):
+ raise
+ except Exception as e:
+ _log_error(f"Editable build failed: {e}")
+ raise BuildError(f"Failed to build editable wheel: {e}") from e
def get_requires_for_build_wheel(
- config_settings: Mapping[str, Any] | None = None,
+ config_settings: Optional[Mapping[str, Any]] = None,
) -> Sequence[str]:
- if _use_metal_backend():
- # Add dynamic build requirements only when building the Metal backend
- return [
- "scikit-build-core>=0.10",
- "pybind11>=2.12",
- "cmake>=3.26",
- "ninja",
- ]
- # setuptools usually returns []
- return list(_setuptools_backend().get_requires_for_build_wheel(config_settings))
+ """
+ Get the build requirements for building a wheel.
+
+ Args:
+ config_settings: Optional configuration settings.
+
+ Returns:
+ Sequence[str]: List of build requirements.
+
+ Raises:
+ ConfigurationError: If there's a configuration issue.
+ """
+ try:
+ if _use_metal_backend():
+ requirements = METAL_BUILD_REQUIREMENTS.copy()
+ _log_info(f"Metal backend requires: {', '.join(requirements)}")
+ return requirements
+ else:
+ # Get requirements from setuptools backend
+ setuptools_backend = _setuptools_backend()
+ fn = _safe_getattr(setuptools_backend, "get_requires_for_build_wheel")
+
+ if fn is not None:
+ requirements = list(fn(config_settings))
+ else:
+ requirements = SETUPTOOLS_BUILD_REQUIREMENTS.copy()
+
+ _log_info(f"Setuptools backend requires: {', '.join(requirements) if requirements else 'no additional packages'}")
+ return requirements
+ except Exception as e:
+ _log_error(f"Failed to get build requirements: {e}")
+ raise ConfigurationError(f"Unable to determine build requirements: {e}") from e
def get_requires_for_build_sdist(
- config_settings: Mapping[str, Any] | None = None,
+ config_settings: Optional[Mapping[str, Any]] = None,
) -> Sequence[str]:
- # No special requirements for SDist
- be = _backend()
- fn = getattr(be, "get_requires_for_build_sdist", None)
- if fn is None:
- return []
- return list(fn(config_settings))
+ """
+ Get the build requirements for building a source distribution.
+
+ Args:
+ config_settings: Optional configuration settings.
+
+ Returns:
+ Sequence[str]: List of build requirements.
+
+ Raises:
+ ConfigurationError: If there's a configuration issue.
+ """
+ try:
+ backend = _backend()
+ fn = _safe_getattr(backend, "get_requires_for_build_sdist")
+
+ if fn is None:
+ requirements: List[str] = []
+ else:
+ requirements = list(fn(config_settings))
+
+ _log_info(f"SDist requires: {', '.join(requirements) if requirements else 'no additional packages'}")
+ return requirements
+ except Exception as e:
+ _log_error(f"Failed to get sdist requirements: {e}")
+ raise ConfigurationError(f"Unable to determine sdist requirements: {e}") from e
def get_requires_for_build_editable(
- config_settings: Mapping[str, Any] | None = None,
+ config_settings: Optional[Mapping[str, Any]] = None,
+) -> Sequence[str]:
+ """
+ Get the build requirements for building an editable install.
+
+ Args:
+ config_settings: Optional configuration settings.
+
+ Returns:
+ Sequence[str]: List of build requirements.
+
+ Raises:
+ ConfigurationError: If there's a configuration issue.
+ """
+ try:
+ if _use_metal_backend():
+ requirements = METAL_BUILD_REQUIREMENTS.copy()
+ _log_info(f"Editable Metal backend requires: {', '.join(requirements)}")
+ return requirements
+ else:
+ setuptools_backend = _setuptools_backend()
+ fn = _safe_getattr(setuptools_backend, "get_requires_for_build_editable")
+
+ if fn is None:
+ requirements: List[str] = []
+ else:
+ requirements = list(fn(config_settings))
+
+ _log_info(f"Editable setuptools backend requires: {', '.join(requirements) if requirements else 'no additional packages'}")
+ return requirements
+ except Exception as e:
+ _log_error(f"Failed to get editable build requirements: {e}")
+ raise ConfigurationError(f"Unable to determine editable build requirements: {e}") from e
+
+
+# Future expansion hooks (currently unused but available for extension)
+
+def get_requires_for_build_meta(
+ config_settings: Optional[Mapping[str, Any]] = None,
) -> Sequence[str]:
- if _use_metal_backend():
- return [
- "scikit-build-core>=0.10",
- "pybind11>=2.12",
- "cmake>=3.26",
- "ninja",
- ]
- be = _setuptools_backend()
- fn = getattr(be, "get_requires_for_build_editable", None)
- if fn is None:
- return []
- return list(fn(config_settings))
\ No newline at end of file
+ """
+ Get requirements for building metadata (future PEP extension).
+
+ Args:
+ config_settings: Optional configuration settings.
+
+ Returns:
+ Sequence[str]: List of requirements (currently empty).
+ """
+ _log_info("get_requires_for_build_meta called (future extension)")
+ return []
+
+
+def build_meta(
+ meta_directory: str,
+ config_settings: Optional[Mapping[str, Any]] = None,
+) -> str:
+ """
+ Build metadata only (future PEP extension).
+
+ Args:
+ meta_directory: Directory for metadata.
+ config_settings: Optional configuration settings.
+
+ Returns:
+ str: Metadata filename.
+
+ Raises:
+ BuildError: Always, as this is not yet implemented.
+ """
+ _log_info("build_meta called (future extension)")
+ raise BuildError("build_meta is not yet implemented")
\ No newline at end of file
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index 5e40079..8a749cf 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -2,9 +2,9 @@
Harmony chat with tools
"""
-import atexit
import argparse
import asyncio
+import atexit
import datetime
import os
from pathlib import Path
@@ -14,14 +14,8 @@
except ImportError:
import readline
-import torch
import termcolor
-
-from gpt_oss.tools import apply_patch
-from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
-from gpt_oss.tools.python_docker.docker_tool import PythonTool
-
+import torch
from openai_harmony import (
Author,
Conversation,
@@ -38,6 +32,10 @@
load_harmony_encoding,
)
+from gpt_oss.tools import apply_patch
+from gpt_oss.tools.python_docker.docker_tool import PythonTool
+from gpt_oss.tools.simple_browser import SimpleBrowserTool
+from gpt_oss.tools.simple_browser.backend import ExaBackend
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
@@ -49,7 +47,10 @@
def get_user_input():
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if rank == 0:
- user_input = input()
+ try:
+ user_input = input()
+ except EOFError:
+ return "/exit"
else:
user_input = ""
user_input_list = [user_input]
@@ -59,35 +60,42 @@ def get_user_input():
def main(args):
+ # --- Backend & generator init
match args.backend:
case "triton":
- from gpt_oss.triton.model import TokenGenerator as TritonGenerator
from gpt_oss.torch.utils import init_distributed
+ from gpt_oss.triton.model import TokenGenerator as TritonGenerator
+
device = init_distributed()
generator = TritonGenerator(args.checkpoint, args.context, device)
case "torch":
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
from gpt_oss.torch.utils import init_distributed
+
device = init_distributed()
generator = TorchGenerator(args.checkpoint, device)
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
- generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+
+ generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tp_size)
case _:
raise ValueError(f"Invalid backend: {args.backend}")
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+ # --- System message
system_message_content = (
SystemContent.new()
.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
)
+ # Tools
+ browser_tool = None
+ python_tool = None
+
if args.browser:
- backend = ExaBackend(
- source="web",
- )
+ backend = ExaBackend(source="web")
browser_tool = SimpleBrowserTool(backend=backend)
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
@@ -98,34 +106,35 @@ def main(args):
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
messages = [system_message]
+ # --- Developer message (resolved)
+ developer_message_content = None
if args.apply_patch:
apply_patch_instructions = Path(apply_patch.__file__).parent / "apply_patch.md"
- developer_message = ""
- if args.developer_message:
- developer_message = args.developer_message + "\n"
+ developer_message = (args.developer_message + "\n") if args.developer_message else ""
developer_message += apply_patch_instructions.read_text()
developer_message_content = (
DeveloperContent.new()
.with_instructions(developer_message)
- .with_function_tools([
- ToolDescription.new(
- "apply_patch",
- "Patch a file",
- parameters={
- "type": "string",
- "description": "Formatted patch code",
- "default": "*** Begin Patch\n*** End Patch\n",
- }
- ),
- ])
+ .with_function_tools(
+ [
+ ToolDescription.new(
+ "apply_patch",
+ "Patch a file",
+ parameters={
+ "type": "string",
+ "description": "Formatted patch code",
+ "default": "*** Begin Patch\n*** End Patch\n",
+ },
+ ),
+ ]
+ )
)
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
elif args.developer_message:
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
- else:
- developer_message_content = None
+ # --- Raw vs pretty headers
if args.raw:
conversation = Conversation.from_messages(messages)
tokens = encoding.render_conversation(conversation)
@@ -135,7 +144,6 @@ def main(args):
user_message_start = encoding.decode(empty_user_message_tokens[:-1])
user_message_end = encoding.decode(empty_user_message_tokens[-1:])
else:
- # System message
print(termcolor.colored("System Message:", "cyan"), flush=True)
print(termcolor.colored("Model Identity:", "cyan"), system_message_content.model_identity, flush=True)
print(termcolor.colored("Reasoning Effort:", "cyan"), system_message_content.reasoning_effort, flush=True)
@@ -148,138 +156,159 @@ def main(args):
print(termcolor.colored("Developer Message:", "yellow"), flush=True)
print(developer_message_content.instructions, flush=True)
- # Print the system message and the user message start
+ # --- REPL
MESSAGE_PADDING = 12
- while True:
- last_message = messages[-1]
- if last_message.recipient is None:
- if args.raw:
- print(user_message_start, end="", flush=True)
- user_message = get_user_input()
- print(user_message_end, flush=True, end="")
+ try:
+ while True:
+ last_message = messages[-1]
+ if last_message.recipient is None:
+ # Get user input
+ if args.raw:
+ print(user_message_start, end="", flush=True)
+ user_text = get_user_input()
+ print(user_message_end, flush=True, end="")
+ else:
+ print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True)
+ user_text = get_user_input()
+
+ if user_text.strip().lower() in {"/exit", "exit", ":q", "quit"}:
+ print(termcolor.colored("Bye!", "cyan"))
+ break
+
+ user_message = Message.from_role_and_content(Role.USER, user_text)
+ messages.append(user_message)
else:
- print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True)
- user_message = get_user_input()
- user_message = Message.from_role_and_content(Role.USER, user_message)
- messages.append(user_message)
- else:
- # Tool or function call
- if last_message.recipient.startswith("browser."):
- assert args.browser, "Browser tool is not enabled"
- tool_name = "Search"
- async def run_tool():
- results = []
- async for msg in browser_tool.process(last_message):
- results.append(msg)
- return results
-
- result = asyncio.run(run_tool())
- messages += result
- elif last_message.recipient.startswith("python"):
- assert args.python, "Python tool is not enabled"
- tool_name = "Python"
- async def run_tool():
- results = []
- async for msg in python_tool.process(last_message):
- results.append(msg)
- return results
-
- result = asyncio.run(run_tool())
- messages += result
- elif last_message.recipient == "functions.apply_patch":
- assert args.apply_patch, "Apply patch tool is not enabled"
- tool_name = "Apply Patch"
- text = last_message.content[0].text
- tool_output = None
-
- if text.startswith("{"):
- # this is json, try to extract the patch from it
- import json
- try:
- some_dict = json.loads(text)
- _, text = some_dict.popitem()
- except Exception as e:
- tool_output = f"Error parsing JSON: {e}"
-
- if tool_output is None:
- try:
- tool_output = apply_patch.apply_patch(text)
- except Exception as e:
- tool_output = f"Error applying patch: {e}"
-
- message = (
- Message(
- author=Author.new(Role.TOOL, last_message.recipient),
- content=[TextContent(text=tool_output)]
+ # Tool/function dispatch
+ if last_message.recipient.startswith("browser."):
+ assert args.browser, "Browser tool is not enabled"
+ tool_name = "Search"
+
+ async def run_tool():
+ results = []
+ async for msg in browser_tool.process(last_message):
+ results.append(msg)
+ return results
+
+ result = asyncio.run(run_tool())
+ messages += result
+ elif last_message.recipient.startswith("python"):
+ assert args.python, "Python tool is not enabled"
+ tool_name = "Python"
+
+ async def run_tool():
+ results = []
+ async for msg in python_tool.process(last_message):
+ results.append(msg)
+ return results
+
+ result = asyncio.run(run_tool())
+ messages += result
+ elif last_message.recipient == "functions.apply_patch":
+ assert args.apply_patch, "Apply patch tool is not enabled"
+ tool_name = "Apply Patch"
+ text = last_message.content[0].text
+ tool_output = None
+
+ if text.startswith("{"):
+ import json
+ try:
+ some_dict = json.loads(text)
+ _, text = some_dict.popitem()
+ except Exception as e:
+ tool_output = f"Error parsing JSON: {e}"
+
+ if tool_output is None:
+ try:
+ tool_output = apply_patch.apply_patch(text)
+ except Exception as e:
+ tool_output = f"Error applying patch: {e}"
+
+ message = (
+ Message(
+ author=Author.new(Role.TOOL, last_message.recipient),
+ content=[TextContent(text=tool_output)],
+ )
+ .with_recipient("assistant")
)
- .with_recipient("assistant")
- )
- if last_message.channel:
- message = message.with_channel(last_message.channel)
+ if last_message.channel:
+ message = message.with_channel(last_message.channel)
- result = [message]
- messages += result
- else:
- raise ValueError(f"Unknown tool or function call: {last_message.recipient}")
- # Print the tool or function call result
- if args.raw:
- rendered_result = encoding.render_conversation(Conversation.from_messages(result))
- print(encoding.decode(rendered_result), flush=True, end="")
- else:
- print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True)
- if tool_name == "Search" and not args.show_browser_results:
- print("[Search results fed to the model]")
+ result = [message]
+ messages += result
else:
- print(result[0].content[0].text)
+ raise ValueError(f"Unknown tool or function call: {last_message.recipient}")
- conversation = Conversation.from_messages(messages)
- tokens = encoding.render_conversation_for_completion(
- conversation, Role.ASSISTANT
- )
-
- if args.raw:
- # Print the last two tokens, which are the start of the assistant message
- print(encoding.decode(tokens[-2:]), flush=True, end="")
+ # Print tool output
+ if args.raw:
+ rendered_result = encoding.render_conversation(Conversation.from_messages(result))
+ print(encoding.decode(rendered_result), flush=True, end="")
+ else:
+ print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True)
+ if tool_name == "Search" and not args.show_browser_results:
+ print("[Search results fed to the model]")
+ else:
+ # Some tools may return multiple messages; show first non-empty safely
+ first_text = ""
+ for m in result:
+ if m.content and m.content[0].text:
+ first_text = m.content[0].text
+ break
+ print(first_text)
+
+ # --- Model completion stream
+ conversation = Conversation.from_messages(messages)
+ tokens = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)
- parser = StreamableParser(encoding, role=Role.ASSISTANT)
- field_created = False
- current_output_text = ""
- output_text_delta_buffer = ""
- for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
- parser.process(predicted_token)
if args.raw:
- print(encoding.decode([predicted_token]), end="", flush=True)
- continue
-
- if parser.state == StreamState.EXPECT_START:
- print("") # new line
- field_created = False
-
- if not parser.last_content_delta:
- continue
-
- if not field_created:
- field_created = True
- if parser.current_channel == "final":
- print(termcolor.colored("Assistant:", "green"), flush=True)
- elif parser.current_recipient is not None:
- print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True)
- else:
- print(termcolor.colored("CoT:", "yellow"), flush=True)
+ # Print the last two tokens (start of assistant msg)
+ print(encoding.decode(tokens[-2:]), flush=True, end="")
+
+ parser = StreamableParser(encoding, role=Role.ASSISTANT)
+ field_created = False
+ current_output_text = ""
+ output_text_delta_buffer = ""
+
+ for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
+ parser.process(predicted_token)
+ if args.raw:
+ print(encoding.decode([predicted_token]), end="", flush=True)
+ continue
+
+ if parser.state == StreamState.EXPECT_START:
+ print("") # new line
+ field_created = False
+
+ if not parser.last_content_delta:
+ continue
+
+ if not field_created:
+ field_created = True
+ if parser.current_channel == "final":
+ print(termcolor.colored("Assistant:", "green"), flush=True)
+ elif parser.current_recipient is not None:
+ print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True)
+ else:
+ print(termcolor.colored("CoT:", "yellow"), flush=True)
+
+ should_send_output_text_delta = True
+ output_text_delta_buffer += parser.last_content_delta
+
+ if args.browser and browser_tool is not None:
+ updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(
+ current_output_text + output_text_delta_buffer
+ )
+ output_text_delta_buffer = updated_output_text[len(current_output_text) :]
+ if has_partial_citations:
+ should_send_output_text_delta = False
- should_send_output_text_delta = True
- output_text_delta_buffer += parser.last_content_delta
- if args.browser:
- updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer)
- output_text_delta_buffer = updated_output_text[len(current_output_text):]
- if has_partial_citations:
- should_send_output_text_delta = False
- if should_send_output_text_delta:
- print(output_text_delta_buffer, end="", flush=True)
- current_output_text += output_text_delta_buffer
- output_text_delta_buffer = ""
+ if should_send_output_text_delta:
+ print(output_text_delta_buffer, end="", flush=True)
+ current_output_text += output_text_delta_buffer
+ output_text_delta_buffer = ""
- messages += parser.messages
+ messages += parser.messages
+ except KeyboardInterrupt:
+ print("\n" + termcolor.colored("Interrupted. Bye!", "cyan"))
if __name__ == "__main__":
@@ -354,6 +383,12 @@ async def run_tool():
choices=["triton", "torch", "vllm"],
help="Inference backend",
)
+ parser.add_argument(
+ "--tp-size",
+ type=int,
+ default=2,
+ help="Tensor parallel size for vLLM backend",
+ )
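+    # Illustrative invocation (the checkpoint path and sizes are placeholders):
+    #   python -m gpt_oss.chat --backend vllm --tp-size 4 /path/to/checkpoint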
args = parser.parse_args()
if int(os.environ.get("WORLD_SIZE", 1)) == 1:
diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py
index bb34e2c..b89963b 100644
--- a/gpt_oss/evals/__main__.py
+++ b/gpt_oss/evals/__main__.py
@@ -3,14 +3,12 @@
from datetime import datetime
from . import report
+from .aime_eval import AIME25Eval
from .basic_eval import BasicEval
+from .chat_completions_sampler import (OPENAI_SYSTEM_MESSAGE_API,
+ ChatCompletionsSampler)
from .gpqa_eval import GPQAEval
-from .aime_eval import AIME25Eval
from .healthbench_eval import HealthBenchEval
-from .chat_completions_sampler import (
- OPENAI_SYSTEM_MESSAGE_API,
- ChatCompletionsSampler,
-)
from .responses_sampler import ResponsesSampler
@@ -62,16 +60,16 @@ def main():
default=1584,
help="Number of threads to run.",
)
- parser.add_argument(
- "--debug", action="store_true", help="Run in debug mode"
- )
+ parser.add_argument("--debug", action="store_true", help="Run in debug mode")
parser.add_argument(
"--examples", type=int, help="Number of examples to use (overrides default)"
)
args = parser.parse_args()
- sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler
+ sampler_cls = (
+ ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler
+ )
models = {}
for model_name in args.model.split(","):
@@ -192,7 +190,8 @@ def get_evals(eval_name, debug_mode):
merge_metrics = []
for eval_model_name, result_filename in mergekey2resultpath.items():
try:
- result = json.load(open(result_filename, "r+"))
+ with open(result_filename, "r") as f:
+ result = json.load(f)
except Exception as e:
print(e, result_filename)
continue
diff --git a/gpt_oss/evals/abcd_grader.py b/gpt_oss/evals/abcd_grader.py
index 37088c8..cd2351f 100644
--- a/gpt_oss/evals/abcd_grader.py
+++ b/gpt_oss/evals/abcd_grader.py
@@ -1,23 +1,22 @@
import re
import sys
-
_PATTERNS = [
# 0)"**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter.
re.compile(
- r'''(?ix) # case‐insensitive, ignore‐space
+ r"""(?ix) # case‐insensitive, ignore‐space
(?:\*{1,2}|_{1,2}) # leading *…* or _…_
Answer[s]? # Answer or Answers
\s*[:\-–]? # optional separator
(?:\*{1,2}|_{1,2}) # closing wrapper
\s* # optional space
([ABCD])\b # the actual letter
- ''',
- re.X
+ """,
+ re.X,
),
-
# 0.1)
- re.compile(r'''(?ix) # ignore case, allow verbose mode
+ re.compile(
+ r"""(?ix) # ignore case, allow verbose mode
^\s* # optional leading whitespace
(?:\*{1,2}|_{1,2})? # optional markdown wrapper
Answer:? # the word 'answer' with an optional colon
@@ -27,54 +26,56 @@
([ABCD]) # capture the letter
(?:\*{1,2}|_{1,2})? # optional markdown wrapper after letter
\s* # optional trailing whitespace, end of line
- ''', re.MULTILINE),
-
+ """,
+ re.MULTILINE,
+ ),
# 1) Answer: (C) or Answers: (B)
- re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'),
-
+ re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)"),
# 2) Answer: C or Answers – D
- re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'),
-
+ re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b"),
# 3) Option B or Choice: C
- re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'),
-
+ re.compile(r"(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b"),
# 7) LaTeX \boxed{...A...}, catches both \boxed{A} and
# \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc.
- re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE),
-
+ re.compile(r"(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}", re.MULTILINE),
# 7.5) LaTeX \boxed{\textbf{...C...}}
- re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
-
+ re.compile(
+ r"(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE
+ ),
# 7.51) LaTeX \boxed{\text{...C...}}
- re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
-
+ re.compile(
+ r"(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE
+ ),
# 4) bare singletons: (A) [B]
- re.compile(r'(?x)(? str | None:
m = pat.search(text)
if m:
letter = m.group(1).upper()
- if letter in 'ABCD':
+ if letter in "ABCD":
matches.append((prio, m, letter))
- matches.sort(key=lambda triple: (
- triple[0],
- len(triple[1].group(0))
- ))
+ matches.sort(key=lambda triple: (triple[0], len(triple[1].group(0))))
for _, match, letter in matches:
return letter
- return text.removeprefix('**')[:1]
+ return text.removeprefix("**")[:1]
def main():
if len(sys.argv) > 1:
# Process files
for fn in sys.argv[1:]:
- with open(fn, encoding='utf8') as fp:
+ with open(fn, encoding="utf8") as fp:
text = fp.read()
ans = extract_abcd(text)
print(f"{fn} ➜ {ans!r}")
@@ -118,4 +116,3 @@ def main():
if __name__ == "__main__":
main()
-
diff --git a/gpt_oss/evals/aime_eval.py b/gpt_oss/evals/aime_eval.py
index c6e9d64..0c5033b 100644
--- a/gpt_oss/evals/aime_eval.py
+++ b/gpt_oss/evals/aime_eval.py
@@ -1,64 +1,82 @@
"""
AIME 2025: https://huggingface.co/datasets/opencompass/AIME2025
"""
+
import random
import re
+
import pandas
-from . import report
+from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
-
AIME_TEMPLATE = """
{question}
Please reason step by step, and put your final answer within \\boxed{{}}.
"""
+
def format_aime_question(row):
return AIME_TEMPLATE.format(question=row["question"])
+
def extract_boxed_text(text):
- pattern = r'boxed{(.*?)}|framebox{(.*?)}'
+ pattern = r"boxed{(.*?)}|framebox{(.*?)}"
matches = re.findall(pattern, text, re.DOTALL)
if matches:
for match in matches[::-1]:
for group in match:
if group != "":
- return group.split(',')[-1].strip()
- pattern = r'\d+' # get the last integer if no pattern found
+ return group.split(",")[-1].strip()
+ pattern = r"\d+" # get the last integer if no pattern found
matches = re.findall(pattern, text, re.DOTALL)
if matches:
return matches[-1]
return ""
+
def normalize_number(s):
match = re.match(r"\d+", s) # match digits from the start
if not match:
return None
return match.group(0)
+
class AIME25Eval(Eval):
def __init__(
self,
n_repeats: int = 4,
- num_examples: int | None = None, # restrict to a subset of the data for debugging
+ num_examples: (
+ int | None
+ ) = None, # restrict to a subset of the data for debugging
n_threads: int = 1,
):
path1 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl"
df1 = pandas.read_json(path1, lines=True)
path2 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl"
df2 = pandas.read_json(path2, lines=True)
- examples = [row.to_dict() for _, row in df1.iterrows()] + [row.to_dict() for _, row in df2.iterrows()]
- examples = [{
- "question": row["question"],
- "answer": normalize_number(row["answer"]) if isinstance(row["answer"], str) else row["answer"],
- } for row in examples]
+ examples = [row.to_dict() for _, row in df1.iterrows()] + [
+ row.to_dict() for _, row in df2.iterrows()
+ ]
+ examples = [
+ {
+ "question": row["question"],
+ "answer": (
+ normalize_number(row["answer"])
+ if isinstance(row["answer"], str)
+ else row["answer"]
+ ),
+ }
+ for row in examples
+ ]
rng = random.Random(0)
if num_examples:
assert n_repeats == 1, "n_repeats only supported for num_examples = None"
examples = rng.sample(examples, num_examples)
examples = examples * n_repeats
- examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
+ examples = [
+ example | {"permutation": rng.sample(range(4), 4)} for example in examples
+ ]
self.examples = examples
self.n_repeats = n_repeats
self.n_threads = n_threads
@@ -66,16 +84,16 @@ def __init__(
def __call__(self, sampler: SamplerBase) -> EvalResult:
def fn(row: dict):
prompt_messages = [
- sampler._pack_message(
- content=format_aime_question(row), role="user"
- )
+ sampler._pack_message(content=format_aime_question(row), role="user")
]
sampler_response = sampler(prompt_messages)
response_text = sampler_response.response_text
- actual_queried_prompt_messages = sampler_response.actual_queried_message_list
+ actual_queried_prompt_messages = (
+ sampler_response.actual_queried_message_list
+ )
extracted_answer = extract_boxed_text(response_text)
correct_answer = int(row["answer"])
- try: # All AIME answers are integers, so we convert the extracted answer to an integer
+ try: # All AIME answers are integers, so we convert the extracted answer to an integer
extracted_answer = int(extracted_answer)
except (ValueError, TypeError):
extracted_answer = None
@@ -87,11 +105,17 @@ def fn(row: dict):
correct_answer=correct_answer,
extracted_answer=extracted_answer,
)
- convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
+ convo = actual_queried_prompt_messages + [
+ dict(content=response_text, role="assistant")
+ ]
return SingleEvalResult(
- html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+ html=html,
+ score=score,
+ convo=convo,
+ metrics={"chars": len(response_text)},
)
- results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
+ results = report.map_with_progress(
+ fn, self.examples, num_threads=self.n_threads
+ )
return report.aggregate_results(results)
-
diff --git a/gpt_oss/evals/basic_eval.py b/gpt_oss/evals/basic_eval.py
index 7799530..89537c1 100644
--- a/gpt_oss/evals/basic_eval.py
+++ b/gpt_oss/evals/basic_eval.py
@@ -1,25 +1,32 @@
"""
Basic eval
"""
-from . import report
+from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+
class BasicEval(Eval):
- def __init__(self,):
- self.examples = [{
- "question": "hi",
- "answer": "hi, how can i help?",
- }]
+ def __init__(
+ self,
+ ):
+ self.examples = [
+ {
+ "question": "hi",
+ "answer": "hi, how can i help?",
+ }
+ ]
def __call__(self, sampler: SamplerBase) -> EvalResult:
def fn(row: dict):
- sampler_response = sampler([
- sampler._pack_message(content=row["question"], role="user")
- ])
+ sampler_response = sampler(
+ [sampler._pack_message(content=row["question"], role="user")]
+ )
response_text = sampler_response.response_text
extracted_answer = response_text
- actual_queried_prompt_messages = sampler_response.actual_queried_message_list
+ actual_queried_prompt_messages = (
+ sampler_response.actual_queried_message_list
+ )
score = 1.0 if len(extracted_answer) > 0 else 0.0
html = report.jinja_env.from_string(report.HTML_JINJA).render(
prompt_messages=actual_queried_prompt_messages,
@@ -28,11 +35,15 @@ def fn(row: dict):
correct_answer=row["answer"],
extracted_answer=extracted_answer,
)
- convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
+ convo = actual_queried_prompt_messages + [
+ dict(content=response_text, role="assistant")
+ ]
return SingleEvalResult(
- html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+ html=html,
+ score=score,
+ convo=convo,
+ metrics={"chars": len(response_text)},
)
results = report.map_with_progress(fn, self.examples, num_threads=1)
return report.aggregate_results(results)
-
diff --git a/gpt_oss/evals/chat_completions_sampler.py b/gpt_oss/evals/chat_completions_sampler.py
index 29c1a0a..cfd72eb 100644
--- a/gpt_oss/evals/chat_completions_sampler.py
+++ b/gpt_oss/evals/chat_completions_sampler.py
@@ -2,11 +2,11 @@
from typing import Any
import openai
+import structlog
from openai import OpenAI
from .types import MessageList, SamplerBase, SamplerResponse
-
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
@@ -66,7 +66,9 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
choice = response.choices[0]
content = choice.message.content
if getattr(choice.message, "reasoning", None):
- message_list.append(self._pack_message("assistant", choice.message.reasoning))
+ message_list.append(
+ self._pack_message("assistant", choice.message.reasoning)
+ )
if not content:
raise ValueError("OpenAI API returned empty response; retrying")
@@ -76,18 +78,35 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
actual_queried_message_list=message_list,
)
except openai.BadRequestError as e:
- print("Bad Request Error", e)
+ logger = structlog.get_logger()
+ logger.error("Bad request error from OpenAI API", error=str(e))
return SamplerResponse(
response_text="No response (bad request).",
response_metadata={"usage": None},
actual_queried_message_list=message_list,
)
- except Exception as e:
- exception_backoff = 2 ** trial # exponential back off
- print(
- f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
- e,
+ except (
+ openai.RateLimitError,
+ openai.APITimeoutError,
+ openai.APIConnectionError,
+ ) as e:
+ exception_backoff = min(
+ 2**trial, 60
+ ) # exponential backoff with max 60s
+ logger = structlog.get_logger()
+ logger.warning(
+ "API error, retrying",
+ trial=trial,
+ backoff_seconds=exception_backoff,
+ error=str(e),
+ error_type=type(e).__name__,
)
time.sleep(exception_backoff)
trial += 1
+ except Exception as e:
+ logger = structlog.get_logger()
+ logger.error(
+ "Unexpected error during API call", error=str(e), trial=trial
+ )
+ raise # Re-raise unexpected errors
# unknown error shall throw exception
diff --git a/gpt_oss/evals/gpqa_eval.py b/gpt_oss/evals/gpqa_eval.py
index 1b12a43..42827f5 100644
--- a/gpt_oss/evals/gpqa_eval.py
+++ b/gpt_oss/evals/gpqa_eval.py
@@ -9,9 +9,8 @@
import pandas
from . import report
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from .abcd_grader import extract_abcd
-
+from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
QUERY_TEMPLATE_MULTICHOICE = """
{Question}
@@ -34,7 +33,9 @@ def __init__(
self,
n_repeats: int = 8,
variant: str = "diamond",
- num_examples: int | None = None, # restrict to a subset of the data for debugging
+ num_examples: (
+ int | None
+ ) = None, # restrict to a subset of the data for debugging
debug: bool = False,
n_threads: int = 1,
):
@@ -44,15 +45,23 @@ def __init__(
rng = random.Random(0)
if debug:
- examples = [row.to_dict() for _, row in df.iterrows() if "ESPRESSO spectrograph, please" in row["Question"]]
+ examples = [
+ row.to_dict()
+ for _, row in df.iterrows()
+ if "ESPRESSO spectrograph, please" in row["Question"]
+ ]
else:
examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples:
- assert n_repeats == 1, "n_repeats only supported for num_examples = None"
+ assert (
+ n_repeats == 1
+ ), "n_repeats only supported for num_examples = None"
examples = rng.sample(examples, num_examples)
examples = examples * n_repeats
- examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
+ examples = [
+ example | {"permutation": rng.sample(range(4), 4)} for example in examples
+ ]
self.examples = examples
self.n_repeats = n_repeats
self.n_threads = n_threads
@@ -69,7 +78,11 @@ def fn(row: dict):
correct_index = choices.index(row["Correct Answer"])
correct_answer = "ABCD"[correct_index]
choices_dict = dict(
- A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"]
+ A=choices[0],
+ B=choices[1],
+ C=choices[2],
+ D=choices[3],
+ Question=row["Question"],
)
prompt_messages = [
sampler._pack_message(
@@ -78,7 +91,9 @@ def fn(row: dict):
]
sampler_response = sampler(prompt_messages)
response_text = sampler_response.response_text
- actual_queried_prompt_messages = sampler_response.actual_queried_message_list
+ actual_queried_prompt_messages = (
+ sampler_response.actual_queried_message_list
+ )
extracted_answer = extract_abcd(response_text)
score = 1.0 if extracted_answer == correct_answer else 0.0
html = report.jinja_env.from_string(report.HTML_JINJA).render(
@@ -88,12 +103,19 @@ def fn(row: dict):
correct_answer=correct_answer,
extracted_answer=extracted_answer,
)
- convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
+ convo = actual_queried_prompt_messages + [
+ dict(content=response_text, role="assistant")
+ ]
return SingleEvalResult(
- html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+ html=html,
+ score=score,
+ convo=convo,
+ metrics={"chars": len(response_text)},
)
- results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
+ results = report.map_with_progress(
+ fn, self.examples, num_threads=self.n_threads
+ )
return report.aggregate_results(results)
@@ -122,4 +144,4 @@ def fn(row: dict):
print("--------------------------------")
pass_rate = passes / len(results["convos"])
- print(f"pass@1: {pass_rate}")
\ No newline at end of file
+ print(f"pass@1: {pass_rate}")
diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py
index 09d184c..1467bad 100644
--- a/gpt_oss/evals/healthbench_eval.py
+++ b/gpt_oss/evals/healthbench_eval.py
@@ -26,10 +26,8 @@
import numpy as np
from . import report
-from .chat_completions_sampler import (
- OPENAI_SYSTEM_MESSAGE_API,
- ChatCompletionsSampler,
-)
+from .chat_completions_sampler import (OPENAI_SYSTEM_MESSAGE_API,
+ ChatCompletionsSampler)
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl"
@@ -257,14 +255,12 @@ def __init__(
subset_name: Literal["hard", "consensus"] | None = None,
):
if run_reference_completions:
- assert physician_completions_mode is not None, (
- "physician_completions_mode must be provided if run_reference_completions is True"
- )
+ assert (
+ physician_completions_mode is not None
+ ), "physician_completions_mode must be provided if run_reference_completions is True"
assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][
"has_reference"
- ], (
- "physician_completions_mode must have reference completions if run_reference_completions is True"
- )
+ ], "physician_completions_mode must have reference completions if run_reference_completions is True"
if subset_name == "hard":
input_path = INPUT_PATH_HARD
@@ -284,9 +280,9 @@ def __init__(
# physician completions mode
self.physician_completions_mode = physician_completions_mode
if self.physician_completions_mode is not None:
- assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (
- f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
- )
+ assert (
+ self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES
+ ), f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
# subset to only the rows which have physician completions from that group
examples_matching_mode = [
example
diff --git a/gpt_oss/evals/report.py b/gpt_oss/evals/report.py
index 787dd1f..3d2a028 100644
--- a/gpt_oss/evals/report.py
+++ b/gpt_oss/evals/report.py
@@ -9,7 +9,6 @@
from .types import EvalResult, Message, SingleEvalResult
-
HTML_JINJA = """
Prompt conversation
{% for message in prompt_messages %}
diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py
index fd9daef..4d76250 100644
--- a/gpt_oss/evals/responses_sampler.py
+++ b/gpt_oss/evals/responses_sampler.py
@@ -22,7 +22,7 @@ def __init__(
reasoning_effort: str | None = None,
base_url: str = "http://localhost:8000/v1",
):
- self.client = OpenAI(base_url=base_url, timeout=24*60*60)
+ self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)
self.model = model
self.developer_message = developer_message
self.temperature = temperature
@@ -63,7 +63,11 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
for output in response.output:
if hasattr(output, "text"):
- message_list.append(self._pack_message(getattr(output, "role", "assistant"), output.text))
+ message_list.append(
+ self._pack_message(
+ getattr(output, "role", "assistant"), output.text
+ )
+ )
elif hasattr(output, "content"):
for c in output.content:
# c.text handled below
diff --git a/gpt_oss/evals/types.py b/gpt_oss/evals/types.py
index 8f2b42d..3887095 100644
--- a/gpt_oss/evals/types.py
+++ b/gpt_oss/evals/types.py
@@ -5,16 +5,17 @@
MessageList = list[Message]
-
@dataclass
class SamplerResponse:
"""
Response from a sampler.
"""
+
response_text: str
actual_queried_message_list: MessageList
response_metadata: dict[str, Any]
+
class SamplerBase:
"""
Base class for defining a sampling model, which can be evaluated,
@@ -22,7 +23,7 @@ class SamplerBase:
"""
def __call__(
- self,
+ self,
message_list: MessageList,
) -> SamplerResponse:
raise NotImplementedError
@@ -63,4 +64,3 @@ class Eval:
def __call__(self, sampler: SamplerBase) -> EvalResult:
raise NotImplementedError
-
diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index c075580..728093e 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -1,6 +1,7 @@
# Model parallel inference
# Note: This script is for demonstration purposes only. It is not designed for production use.
# See gpt_oss.chat for a more complete example with the Harmony parser.
+# Example:
# torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/
import argparse
@@ -11,30 +12,47 @@
def main(args):
match args.backend:
case "torch":
- from gpt_oss.torch.utils import init_distributed
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
+ from gpt_oss.torch.utils import init_distributed
+
device = init_distributed()
generator = TorchGenerator(args.checkpoint, device=device)
+
case "triton":
from gpt_oss.torch.utils import init_distributed
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
+
device = init_distributed()
- generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)
+ generator = TritonGenerator(
+ args.checkpoint,
+ context=args.context_length,
+ device=device
+ )
+
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
- generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
+ generator = VLLMGenerator(
+ args.checkpoint,
+ tensor_parallel_size=args.tensor_parallel_size
+ )
+
case _:
raise ValueError(f"Invalid backend: {args.backend}")
tokenizer = get_tokenizer()
tokens = tokenizer.encode(args.prompt)
max_tokens = None if args.limit == 0 else args.limit
- for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True):
+
+ for token, logprob in generator.generate(
+ tokens,
+ stop_tokens=[tokenizer.eot_token],
+ temperature=args.temperature,
+ max_tokens=max_tokens,
+ return_logprobs=True,
+ ):
tokens.append(token)
token_text = tokenizer.decode([token])
- print(
- f"Generated token: {repr(token_text)}, logprob: {logprob}"
- )
+ print(f"Generated token: {repr(token_text)}, logprob: {logprob}")
if __name__ == "__main__":
@@ -91,5 +109,4 @@ def main(args):
help="Context length for Triton backend",
)
args = parser.parse_args()
-
main(args)
diff --git a/gpt_oss/logging_config.py b/gpt_oss/logging_config.py
new file mode 100644
index 0000000..42a7a3c
--- /dev/null
+++ b/gpt_oss/logging_config.py
@@ -0,0 +1,90 @@
+"""
+Logging configuration for gpt-oss package.
+Provides structured logging with appropriate levels and formatting.
+"""
+
+import logging
+import sys
+from typing import Any, Dict, Optional
+
+import structlog
+
+# Configure structlog processors
+structlog.configure(
+ processors=[
+ # Add log level to log entries
+ structlog.stdlib.add_log_level,
+ # Add a timestamp in ISO format
+ structlog.processors.TimeStamper(fmt="iso"),
+ # If the "stack_info" key in the event dict is true, remove it and
+ # render the current stack trace in the "stack" key.
+ structlog.processors.StackInfoRenderer(),
+ # Format the exception only once, even if there are multiple loggers
+ structlog.processors.format_exc_info,
+ # Render in JSON for production, or pretty print for development
+ (
+ structlog.dev.ConsoleRenderer()
+ if sys.stderr.isatty()
+ else structlog.processors.JSONRenderer()
+ ),
+ ],
+ # Our `wrapper_class` is used for passing metadata when binding
+ wrapper_class=structlog.stdlib.BoundLogger,
+ # `logger_factory` is used to create wrapped loggers that are used for OUTPUT.
+ logger_factory=structlog.stdlib.LoggerFactory(),
+ # Cache the logger for better performance
+ cache_logger_on_first_use=True,
+)
+
+
+def configure_logging(
+ level: str = "INFO", format_json: bool = False, component: Optional[str] = None
+) -> None:
+ """
+ Configure logging for the application.
+
+ Args:
+ level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+ format_json: Whether to format logs as JSON
+ component: Optional component name to include in logs
+ """
+ logging.basicConfig(
+ format="%(message)s",
+ stream=sys.stderr,
+ level=getattr(logging, level.upper()),
+ )
+
+ # Update structlog processors based on configuration
+ processors = [
+ structlog.stdlib.add_log_level,
+ structlog.processors.TimeStamper(fmt="iso"),
+ structlog.processors.StackInfoRenderer(),
+ structlog.processors.format_exc_info,
+ ]
+
+ if component:
+ processors.append(structlog.processors.CallsiteParameterAdder())
+
+ if format_json:
+ processors.append(structlog.processors.JSONRenderer())
+ else:
+ processors.append(structlog.dev.ConsoleRenderer())
+
+ structlog.configure(processors=processors)
+
+
+def get_logger(name: str, **initial_values: Any) -> structlog.stdlib.BoundLogger:
+ """
+ Get a structured logger with optional initial values.
+
+ Args:
+ name: Logger name (typically __name__)
+ **initial_values: Key-value pairs to include in all log messages
+
+ Returns:
+ Structured logger instance
+ """
+ logger = structlog.get_logger(name)
+ if initial_values:
+ logger = logger.bind(**initial_values)
+ return logger
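
A brief usage sketch for the new logging helpers added above; the component name and bound fields are illustrative, not values used elsewhere in the codebase.

```python
# Illustrative only: configure structured logging and emit one event.
from gpt_oss.logging_config import configure_logging, get_logger

configure_logging(level="DEBUG", format_json=False, component="responses_api")
log = get_logger(__name__, backend="vllm")  # "backend" is an arbitrary bound field
log.info("server_started", port=8000)       # port value is an example
```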
diff --git a/gpt_oss/metal/examples/chat.py b/gpt_oss/metal/examples/chat.py
index f29cec3..c2d8648 100755
--- a/gpt_oss/metal/examples/chat.py
+++ b/gpt_oss/metal/examples/chat.py
@@ -2,10 +2,9 @@
import argparse
import sys
-
from datetime import date
-from gpt_oss.metal import Context, Model
+from gpt_oss.metal import Context, Model
DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
@@ -16,8 +15,16 @@
# Valid channels: analysis, final. Channel must be included for every message."""
-parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
+parser = argparse.ArgumentParser(
+ description="Chat with gpt-oss",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+ "model",
+ metavar="PATH",
+ type=str,
+ help="Path to gpt-oss model in Metal inference format",
+)
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
parser.add_argument(
"--context-length", type=int, default=0, help="The maximum context length"
@@ -25,9 +32,7 @@
parser.add_argument(
"--temperature", type=float, default=1.0, help="Sampling temperature"
)
-parser.add_argument(
- "--seed", type=int, default=0, help="Sampling seed"
-)
+parser.add_argument("--seed", type=int, default=0, help="Sampling seed")
GREY = "\33[90m"
diff --git a/gpt_oss/metal/examples/generate.py b/gpt_oss/metal/examples/generate.py
index 3b78199..9e63848 100644
--- a/gpt_oss/metal/examples/generate.py
+++ b/gpt_oss/metal/examples/generate.py
@@ -5,12 +5,20 @@
from gpt_oss.metal import Context, Model
-
-parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
-parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
-parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
-parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
+parser = argparse.ArgumentParser(
+ description="Chat with gpt-oss",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+ "model", metavar="PATH", type=str, help="Path to gpt-oss checkpoint"
+)
+parser.add_argument("-p", "--prompt", type=str, required=True, help="Prompt")
+parser.add_argument(
+ "-l", "--limit", type=int, default=100, help="Number of tokens to generate"
+)
+parser.add_argument(
+ "--context-length", type=int, default=0, help="The maximum context length"
+)
def main(args):
@@ -27,8 +35,8 @@ def main(args):
while context.num_tokens - prompt_tokens < options.limit:
token = context.sample()
context.append(token)
- print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
+ print(str(tokenizer.decode(token), encoding="utf-8"), end="", flush=True)
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
diff --git a/gpt_oss/metal/scripts/create-local-model.py b/gpt_oss/metal/scripts/create-local-model.py
index c0de8bd..12bb5ff 100644
--- a/gpt_oss/metal/scripts/create-local-model.py
+++ b/gpt_oss/metal/scripts/create-local-model.py
@@ -1,22 +1,37 @@
import argparse
-import os
-import math
-import sys
-import json
import itertools
+import json
+import math
+import os
import struct
+import sys
from uuid import UUID
import tiktoken
-
import torch
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
from safetensors import safe_open
from tqdm import tqdm
-from openai_harmony import load_harmony_encoding, HarmonyEncodingName
-parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validated MXFP4 weights')
-parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')
-parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')
+parser = argparse.ArgumentParser(
+ prog="check-mxfp4-weights.py", description="Validated MXFP4 weights"
+)
+parser.add_argument(
+ "-s",
+ "--src",
+ metavar="DIR",
+ type=str,
+ required=True,
+ help="Path to the input checkpoint directory",
+)
+parser.add_argument(
+ "-d",
+ "--dst",
+ metavar="FILE",
+ type=str,
+ required=True,
+ help="Path to the output model file",
+)
o200k_base = tiktoken.get_encoding("o200k_base")
@@ -44,21 +59,36 @@
"<|reversed200011|>": 200011, # unused
"<|call|>": 200012,
"<|refusal|>": 200013,
- }
+ },
)
-FILE_MAGIC = struct.pack('ccccccccccccI', b'G', b'P', b'T', b'-', b'O', b'S', b'S', b' ', b'v', b'1', b'.', b'0', 0)
+FILE_MAGIC = struct.pack(
+ "ccccccccccccI",
+ b"G",
+ b"P",
+ b"T",
+ b"-",
+ b"O",
+ b"S",
+ b"S",
+ b" ",
+ b"v",
+ b"1",
+ b".",
+ b"0",
+ 0,
+)
SPECIAL_TOKEN_UUID = {
- '<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes,
- '<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes,
- '<|end|>': UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes,
- '<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes,
- '<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes,
- '<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes,
- '<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes,
- '<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes,
- '<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes,
- '<|end_untrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes,
+ "<|start|>": UUID("55a77c2f-8a01-4c54-8ac2-313bfc7e208d").bytes,
+ "<|message|>": UUID("16e40431-f47f-4b22-b59b-8b278fc30a54").bytes,
+ "<|end|>": UUID("fcac2f6d-4705-4f6b-b228-642accac7238").bytes,
+ "<|return|>": UUID("f799ff69-1992-43c4-a3d8-d831f475dc75").bytes,
+ "<|refusal|>": UUID("e15ba702-28c4-4292-ab8f-ffa434709128").bytes,
+ "<|constrain|>": UUID("c0bb14c7-6022-49da-ad08-792d67e8b470").bytes,
+ "<|channel|>": UUID("fd3dda11-c8ab-4033-876e-d93deb172c93").bytes,
+ "<|call|>": UUID("1220f796-e388-4de5-b487-fe2eb5fe03c0").bytes,
+ "<|untrusted|>": UUID("07d7da55-b346-4cff-8b37-7cefacf8a3e8").bytes,
+ "<|end_untrusted|>": UUID("f265bd9c-c717-469e-a447-920687d65d90").bytes,
}
INCLUDE_SPECIAL_TOKENS = [
@@ -74,62 +104,65 @@
"<|end_untrusted|>",
]
-GPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes
-APPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes
-TIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes
+GPTOSS_MODEL_UUID = UUID("df52dc86-1789-4ed0-a295-66f10508145b").bytes
+APPLE_GPU_LAYOUT_UUID = UUID("229177a8-5775-4268-bfd8-d588b351c56d").bytes
+TIKTOKEN_TOKENIZER_UUID = UUID("7401aded-2a95-40cb-b782-9ccebaafe72b").bytes
UE8_OFFSET = 14 # bias to MXFP4 block scales
+
def write_file_header(f):
f.write(FILE_MAGIC)
-def write_tokenizer_header(f,
- num_special_tokens: int,
- num_text_tokens: int,
- regex_size: int,
- tokens_size: int):
+
+def write_tokenizer_header(
+ f, num_special_tokens: int, num_text_tokens: int, regex_size: int, tokens_size: int
+):
f.write(TIKTOKEN_TOKENIZER_UUID)
[… remaining create-local-model.py hunks and the start of the gpt_oss/responses_api/api_server.py diff are garbled in the source and omitted …]
diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py
--- a/gpt_oss/responses_api/api_server.py
+++ b/gpt_oss/responses_api/api_server.py
def is_not_builtin_tool(recipient: str) -> bool:
- return not recipient.startswith("browser.") and not recipient == "python" and not recipient == "assistant"
+ return (
+ not recipient.startswith("browser.")
+ and recipient != "python"
+ and recipient != "assistant"
+ )
+
def create_api_server(
- infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding
+ infer_next_token: Callable[[List[int], float], int],
+ encoding: HarmonyEncoding,
) -> FastAPI:
app = FastAPI()
- responses_store: dict[str, tuple[ResponsesRequest, ResponseObject]] = {}
+ responses_store: Dict[str, Tuple[ResponsesRequest, ResponseObject]] = {}
def generate_response(
- input_tokens: list[int],
- output_tokens: list[int],
+ input_tokens: List[int],
+ output_tokens: List[int],
request_body: ResponsesRequest,
debug_mode: bool = False,
- function_call_ids: Optional[list[tuple[str, str]]] = None,
+ function_call_ids: Optional[List[Tuple[str, str]]] = None,
response_id: Optional[str] = None,
previous_response_id: Optional[str] = None,
browser_tool: Optional[SimpleBrowserTool] = None,
- browser_call_ids: Optional[list[str]] = None,
+ browser_call_ids: Optional[List[str]] = None,
) -> ResponseObject:
- output = []
- error = None
+ output: List[Item] = []
+ error: Optional[Error] = None
+
if len(output_tokens) > 0:
- if debug_mode:
- try:
- entries = encoding.parse_messages_from_completion_tokens(
- output_tokens, Role.ASSISTANT
- )
- except Exception as e:
- print(f"Error parsing tokens: {e}")
- error = Error(
- code="invalid_function_call",
- message=f"{e}",
- )
- entries = []
- else:
+ # Parse assistant messages from completion tokens
+ try:
entries = encoding.parse_messages_from_completion_tokens(
output_tokens, Role.ASSISTANT
)
+ except Exception as e:
+ # In debug mode, surface a structured error
+ if debug_mode:
+ print(f"Error parsing tokens: {e}")
+ error = Error(code="invalid_function_call", message=str(e))
+ entries = []
+ else:
+ raise
fc_index = 0
browser_tool_index = 0
for entry in entries:
entry_dict = entry.to_dict()
- if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(entry_dict["recipient"]):
+
+ # Function tool calls (non-builtin)
+ if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(
+ entry_dict["recipient"]
+ ):
call = entry_dict["content"][0]
arguments = call["text"]
name = entry_dict["recipient"]
-
if name.startswith("functions."):
name = name[len("functions.") :]
+
if function_call_ids and fc_index < len(function_call_ids):
fc_id, call_id = function_call_ids[fc_index]
else:
@@ -130,6 +138,7 @@ def generate_response(
f"call_{uuid.uuid4().hex}",
)
fc_index += 1
+
output.append(
FunctionCallItem(
type="function_call",
@@ -139,32 +148,34 @@ def generate_response(
call_id=call_id,
)
)
- elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None:
- # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked
+
+ # Browser tool calls
+ elif (
+ len(entry_dict.get("recipient", "")) > 0
+ and entry_dict["recipient"].startswith("browser.")
+ and browser_tool is not None
+ ):
name = entry_dict["recipient"]
call = entry_dict["content"][0]
arguments = call["text"]
- function_name = name[len("browser."):]
+ function_name = name[len("browser.") :]
- # Reconstruct a Message for argument parsing
tool_msg = (
Message.from_role_and_content(Role.ASSISTANT, arguments)
.with_recipient(name)
.with_channel("analysis")
)
- action = None
+ action: Optional[WebSearchActionSearch | WebSearchActionOpenPage | WebSearchActionFind] = None
try:
parsed_args = browser_tool.process_arguments(tool_msg)
if function_name == "search":
action = WebSearchActionSearch(
- type="search",
- query=parsed_args["query"],
+ type="search", query=parsed_args["query"]
)
elif function_name == "open":
action = WebSearchActionOpenPage(
- type="open_page",
- url=parsed_args["url"],
+ type="open_page", url=parsed_args["url"]
)
elif function_name == "find":
action = WebSearchActionFind(
@@ -184,16 +195,18 @@ def generate_response(
browser_tool_index += 1
output.append(
WebSearchCallItem(
- type="web_search_call",
- id=web_search_call_id,
- action=action,
+ type="web_search_call", id=web_search_call_id, action=action
)
)
- elif entry_dict["channel"] == "final":
- content = []
- for content_entry in entry_dict["content"]:
+
+ # Final channel message
+ elif entry_dict.get("channel") == "final":
+ content: List[TextContentItem] = []
+ for content_entry in entry_dict["content"]:
if browser_tool:
- text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"])
+ text_content, annotation_entries, _has_partial = (
+ browser_tool.normalize_citations(content_entry["text"])
+ )
annotations = [UrlCitation(**a) for a in annotation_entries]
else:
text_content = content_entry["text"]
@@ -201,9 +214,7 @@ def generate_response(
content.append(
TextContentItem(
- type="output_text",
- text=text_content,
- annotations=annotations,
+ type="output_text", text=text_content, annotations=annotations
)
)
@@ -215,24 +226,18 @@ def generate_response(
status="completed",
)
)
- elif entry_dict["channel"] == "analysis":
- summary = []
+
+ # Analysis channel (reasoning)
+ elif entry_dict.get("channel") == "analysis":
content = [
ReasoningTextContentItem(
- type="reasoning_text",
- text=entry["text"],
+ type="reasoning_text", text=entry["text"]
)
for entry in entry_dict["content"]
]
output.append(
- ReasoningItem(
- type="reasoning",
- summary=summary,
- content=content,
- )
+ ReasoningItem(type="reasoning", summary=[], content=content)
)
- else:
- output = []
usage = (
Usage(
@@ -246,18 +251,15 @@ def generate_response(
try:
debug_str = encoding.decode_utf8(input_tokens + output_tokens)
- except Exception:
- debug_str = input_tokens + output_tokens
- try:
debug_input_str = encoding.decode_utf8(input_tokens)
- except Exception:
- debug_input_str = input_tokens
- try:
debug_output_str = encoding.decode_utf8(output_tokens)
except Exception:
- debug_output_str = output_tokens
+ # Fallback to raw tokens when decode fails
+ debug_str = input_tokens + output_tokens # type: ignore[assignment]
+ debug_input_str = input_tokens # type: ignore[assignment]
+ debug_output_str = output_tokens # type: ignore[assignment]
- metadata = (
+ metadata: Dict[str, Any] = (
{
"__debug": debug_str,
"__debug_input": debug_input_str,
@@ -281,18 +283,17 @@ def generate_response(
)
class StreamResponsesEvents:
- initial_tokens: list[int]
- tokens: list[int]
- output_tokens: list[int]
+ initial_tokens: List[int]
+ tokens: List[int]
+ output_tokens: List[int]
output_text: str
request_body: ResponsesRequest
request: Request
sequence_number: int
-
def __init__(
self,
- initial_tokens,
+ initial_tokens: List[int],
request_body: ResponsesRequest,
as_sse: bool = False,
request: Optional[Request] = None,
@@ -301,7 +302,7 @@ def __init__(
Callable[[str, ResponsesRequest, ResponseObject], None]
] = None,
browser_tool: Optional[SimpleBrowserTool] = None,
- ):
+ ) -> None:
self.initial_tokens = initial_tokens
self.tokens = initial_tokens.copy()
self.output_tokens = []
@@ -309,10 +310,7 @@ def __init__(
self.request_body = request_body
self.parser = StreamableParser(encoding, role=Role.ASSISTANT)
self.as_sse = as_sse
- self.debug_mode = request_body.metadata.get(
- "__debug", False
- ) # we use this for demo purposes
- # Set temperature for this stream, fallback to DEFAULT_TEMPERATURE if not set
+ self.debug_mode = request_body.metadata.get("__debug", False)
self.temperature = (
request_body.temperature
if request_body.temperature is not None
@@ -320,13 +318,13 @@ def __init__(
)
self.request = request
self.sequence_number = 0
- self.function_call_ids: list[tuple[str, str]] = []
+ self.function_call_ids: List[Tuple[str, str]] = []
self.response_id = response_id
self.store_callback = store_callback
self.new_request = True
self.browser_tool = browser_tool
self.use_browser_tool = browser_tool is not None
- self.browser_call_ids: list[str] = []
+ self.browser_call_ids: List[str] = []
def _send_event(self, event: ResponseEvent):
event.sequence_number = self.sequence_number
@@ -348,46 +346,39 @@ async def run(self):
previous_response_id=self.request_body.previous_response_id,
)
initial_response.status = "in_progress"
+
yield self._send_event(
- ResponseCreatedEvent(
- type="response.created",
- response=initial_response,
- )
+ ResponseCreatedEvent(type="response.created", response=initial_response)
)
yield self._send_event(
ResponseInProgressEvent(
- type="response.in_progress",
- response=initial_response,
+ type="response.in_progress", response=initial_response
)
)
- current_content_index = (
- 0 # for this implementation we will always have one content item only
- )
+ current_content_index = 0
current_output_index = -1
sent_output_item_added = False
- # we use this if the model outputs a citation to buffer until completed
- output_delta_buffer = ""
- # we use this to track the current output text content for things like providing the right indices in citations
- current_output_text_content = ""
- current_annotations = []
+ output_delta_buffer = ""
+ current_output_text_content = ""
+ current_annotations: List[Dict[str, Any]] = []
while True:
- # Check for client disconnect
if self.request is not None and await self.request.is_disconnected():
print("Client disconnected, stopping token generation.")
break
+
next_tok = infer_next_token(
- self.tokens,
- temperature=self.temperature,
- new_request=self.new_request,
+ self.tokens, temperature=self.temperature, new_request=self.new_request
)
self.new_request = False
self.tokens.append(next_tok)
+
try:
self.parser.process(next_tok)
- except Exception as e:
+ except Exception:
+ # Parser may raise before enough tokens; safe to continue
pass
if self.parser.state == StreamState.EXPECT_START:
@@ -396,12 +387,10 @@ async def run(self):
if len(self.parser.messages) > 0:
previous_item = self.parser.messages[-1]
+
if previous_item.recipient is not None:
recipient = previous_item.recipient
- if (
- not recipient.startswith("browser.")
- and not recipient == "python"
- ):
+ if not recipient.startswith("browser.") and recipient != "python":
fc_id = f"fc_{uuid.uuid4().hex}"
call_id = f"call_{uuid.uuid4().hex}"
self.function_call_ids.append((fc_id, call_id))
@@ -412,13 +401,9 @@ async def run(self):
item=FunctionCallItem(
type="function_call",
name=(
- previous_item.recipient[
- len("functions.") :
- ]
- if previous_item.recipient.startswith(
- "functions."
- )
- else previous_item.recipient
+ recipient[len("functions.") :]
+ if recipient.startswith("functions.")
+ else recipient
),
arguments=previous_item.content[0].text,
id=fc_id,
@@ -426,6 +411,7 @@ async def run(self):
),
)
)
+
if previous_item.channel == "analysis":
yield self._send_event(
ResponseReasoningTextDone(
@@ -462,17 +448,19 @@ async def run(self):
),
)
)
+
if previous_item.channel == "final":
annotations = [UrlCitation(**a) for a in current_annotations]
if browser_tool:
- normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text)
+ normalized_text, _annos, _has_partial = browser_tool.normalize_citations(
+ previous_item.content[0].text
+ )
else:
normalized_text = previous_item.content[0].text
annotations = []
+
text_content = TextContentItem(
- type="output_text",
- text=normalized_text,
- annotations=annotations,
+ type="output_text", text=normalized_text, annotations=annotations
)
yield self._send_event(
ResponseOutputTextDone(
@@ -494,16 +482,13 @@ async def run(self):
ResponseOutputItemDone(
type="response.output_item.done",
output_index=current_output_index,
- item=Item(
- type="message",
- role="assistant",
- content=[text_content],
- ),
+ item=Item(type="message", role="assistant", content=[text_content]),
)
)
current_annotations = []
current_output_text_content = ""
+ # Streaming assistant output (final channel)
if (
self.parser.last_content_delta
and self.parser.current_channel == "final"
@@ -529,16 +514,17 @@ async def run(self):
output_delta_buffer += self.parser.last_content_delta
should_send_output_text_delta = True
+
if browser_tool:
- # we normalize on the full current text to get the right indices in citations
- updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer)
- # remove the current text to get back the delta but now normalized
- output_delta_buffer = updated_output_text[len(current_output_text_content):]
-
- # Filter annotations to only include those whose start_index is not already present in current_annotations
- # this is to avoid sending duplicate annotations as multiple annotations can't be in the same place
+ updated_output_text, annotations, has_partial = browser_tool.normalize_citations(
+ current_output_text_content + output_delta_buffer
+ )
+ output_delta_buffer = updated_output_text[len(current_output_text_content) :]
+
existing_start_indices = {a["start_index"] for a in current_annotations}
- new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices]
+ new_annotations = [
+ a for a in annotations if a["start_index"] not in existing_start_indices
+ ]
for a in new_annotations:
current_annotations.append(a)
citation = UrlCitation(**a)
@@ -552,10 +538,9 @@ async def run(self):
)
)
- if has_partial_citations:
+ if has_partial:
should_send_output_text_delta = False
-
if should_send_output_text_delta:
yield self._send_event(
ResponseOutputTextDelta(
@@ -568,6 +553,7 @@ async def run(self):
current_output_text_content += output_delta_buffer
output_delta_buffer = ""
+ # Streaming reasoning (analysis channel)
if (
self.parser.last_content_delta
and self.parser.current_channel == "analysis"
@@ -579,9 +565,7 @@ async def run(self):
ResponseOutputItemAdded(
type="response.output_item.added",
output_index=current_output_index,
- item=ReasoningItem(
- type="reasoning", summary=[], content=[]
- ),
+ item=ReasoningItem(type="reasoning", summary=[], content=[]),
)
)
yield self._send_event(
@@ -592,6 +576,7 @@ async def run(self):
part=ReasoningTextContentItem(type="reasoning_text", text=""),
)
)
+
yield self._send_event(
ResponseReasoningTextDelta(
type="response.reasoning_text.delta",
@@ -602,59 +587,54 @@ async def run(self):
)
try:
- # purely for debugging purposes
output_token_text = encoding.decode_utf8([next_tok])
self.output_text += output_token_text
print(output_token_text, end="", flush=True)
-
except RuntimeError:
pass
if next_tok in encoding.stop_tokens_for_assistant_actions():
if len(self.parser.messages) > 0:
last_message = self.parser.messages[-1]
+ # Handle browser tool invocation mid-stream
if (
self.use_browser_tool
and last_message.recipient is not None
and last_message.recipient.startswith("browser.")
):
- function_name = last_message.recipient[len("browser."):]
+ function_name = last_message.recipient[len("browser.") :]
action = None
parsed_args = browser_tool.process_arguments(last_message)
if function_name == "search":
- action = WebSearchActionSearch(
- type="search",
- query=parsed_args["query"],
- )
+ action = WebSearchActionSearch(type="search", query=parsed_args["query"])
elif function_name == "open":
- action = WebSearchActionOpenPage(
- type="open_page",
- url=parsed_args["url"] if "url" in parsed_args else None,
- )
+ action = WebSearchActionOpenPage(type="open_page", url=parsed_args.get("url"))
elif function_name == "find":
action = WebSearchActionFind(
type="find",
pattern=parsed_args["pattern"],
- url=parsed_args["url"] if "url" in parsed_args else None,
+ url=parsed_args.get("url"),
)
if action is not None:
web_search_call_id = f"ws_{uuid.uuid4().hex}"
self.browser_call_ids.append(web_search_call_id)
- yield self._send_event(ResponseOutputItemAdded(
- type="response.output_item.added",
- output_index=current_output_index,
- item=WebSearchCallItem(
- type="web_search_call",
- id=web_search_call_id,
- action=action,
- ),
- ))
+ yield self._send_event(
+ ResponseOutputItemAdded(
+ type="response.output_item.added",
+ output_index=current_output_index,
+ item=WebSearchCallItem(
+ type="web_search_call",
+ id=web_search_call_id,
+ action=action,
+ ),
+ )
+ )
yield self._send_event(
ResponseWebSearchCallInProgress(
type="response.web_search_call.in_progress",
output_index=current_output_index,
- id=web_search_call_id
+ id=web_search_call_id,
)
)
@@ -676,10 +656,12 @@ async def run_tool():
new_tokens = encoding.render_conversation_for_completion(
Conversation.from_messages(result), Role.ASSISTANT
)
-
+
print(encoding.decode_utf8(new_tokens))
self.output_tokens.append(next_tok)
- self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0])
+ self.tokens.append(
+ encoding.encode("<|end|>", allowed_special="all")[0]
+ )
for token in new_tokens:
self.parser.process(token)
@@ -693,29 +675,29 @@ async def run_tool():
id=web_search_call_id,
)
)
- yield self._send_event(ResponseOutputItemDone(
- type="response.output_item.done",
- output_index=current_output_index,
- item=WebSearchCallItem(
- type="web_search_call",
- id=web_search_call_id,
- action=action,
- ),
- ))
+ yield self._send_event(
+ ResponseOutputItemDone(
+ type="response.output_item.done",
+ output_index=current_output_index,
+ item=WebSearchCallItem(
+ type="web_search_call",
+ id=web_search_call_id,
+ action=action,
+ ),
+ )
+ )
current_output_index += 1
self.new_request = True
-
continue
-
else:
break
else:
raise ValueError("No messages to process")
+
if len(self.output_tokens) >= self.request_body.max_output_tokens:
break
- # Adding in the end if we know we are not done
self.output_tokens.append(next_tok)
if self.request is None or not await self.request.is_disconnected():
@@ -733,28 +715,20 @@ async def run_tool():
if self.store_callback and self.request_body.store:
self.store_callback(self.response_id, self.request_body, response)
yield self._send_event(
- ResponseCompletedEvent(
- type="response.completed",
- response=response,
- )
+ ResponseCompletedEvent(type="response.completed", response=response)
)
@app.post("/v1/responses", response_model=ResponseObject)
async def generate(body: ResponsesRequest, request: Request):
print("request received")
- use_browser_tool = any(
- getattr(tool, "type", None) == "browser_search"
- for tool in (body.tools or [])
- )
+ tools_list = list(getattr(body, "tools", []) or [])
+ use_browser_tool = any(getattr(tool, "type", None) == "browser_search" for tool in tools_list)
+ browser_tool: Optional[SimpleBrowserTool] = None
if use_browser_tool:
- backend = ExaBackend(
- source="web",
- )
+ backend = ExaBackend(source="web")
browser_tool = SimpleBrowserTool(backend=backend)
- else:
- browser_tool = None
if body.previous_response_id:
prev = responses_store.get(body.previous_response_id)
@@ -779,11 +753,10 @@ def _ensure_list(inp):
body.instructions = prev_req.instructions
body.input = merged_input
-
system_message_content = SystemContent.new().with_conversation_start_date(
datetime.datetime.now().strftime("%Y-%m-%d")
)
-
+
if body.reasoning is not None:
try:
@@ -793,116 +766,98 @@ def _ensure_list(inp):
raise HTTPException(status_code=422, detail=str(e))
system_message_content = system_message_content.with_reasoning_effort(reasoning_effort)
- if use_browser_tool:
+ if use_browser_tool and browser_tool is not None:
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
- system_message = Message.from_role_and_content(
- Role.SYSTEM, system_message_content
- )
- messages = [system_message]
+ system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
+ messages: List[Message] = [system_message]
- if body.instructions or body.tools:
- developer_message_content = DeveloperContent.new().with_instructions(
- body.instructions
- )
+ # Developer instructions and function tools
+ if body.instructions or tools_list:
+ developer_message_content = DeveloperContent.new().with_instructions(body.instructions)
- tools = []
- for tool in body.tools:
- if tool.type == "function":
- tools.append(
- ToolDescription.new(
- tool.name,
- tool.description,
- tool.parameters,
- )
+        # Build the function tool descriptions once
+ function_tools: List[ToolDescription] = []
+ for tool in tools_list:
+ if getattr(tool, "type", None) == "function":
+ function_tools.append(
+ ToolDescription.new(tool.name, tool.description, tool.parameters)
)
- if tools:
- developer_message_content = developer_message_content.with_function_tools(
- tools
- )
-
- developer_message = Message.from_role_and_content(
- Role.DEVELOPER, developer_message_content
- )
+ if function_tools:
+ developer_message_content = developer_message_content.with_function_tools(function_tools)
+ developer_message = Message.from_role_and_content(Role.DEVELOPER, developer_message_content)
messages.append(developer_message)
+ # User and prior content
if isinstance(body.input, str):
user_message = Message.from_role_and_content(Role.USER, body.input)
messages.append(user_message)
else:
is_last_message_function_call_output = (
- len(body.input) > 0 and body.input[-1].type == "function_call_output"
+ len(body.input) > 0 and getattr(body.input[-1], "type", None) == "function_call_output"
)
- function_call_map = {}
- # Find the index of the last assistant message
+ function_call_map: Dict[str, FunctionCallItem] = {}
last_assistant_idx = -1
for idx, item in enumerate(body.input):
- if item.type == "message" and item.role == Role.ASSISTANT:
+ if getattr(item, "type", None) == "message" and getattr(item, "role", None) == Role.ASSISTANT:
last_assistant_idx = idx
for idx, item in enumerate(body.input):
- if item.type == "message":
- # TODO: add system prompt handling
- if isinstance(item.content, str):
- messages.append(
- Message.from_role_and_content(item.role, item.content)
- )
+ itype = getattr(item, "type", None)
+ if itype == "message":
+ content = getattr(item, "content", None)
+ role = getattr(item, "role", None)
+ if isinstance(content, str):
+ messages.append(Message.from_role_and_content(role, content))
else:
- for content_item in item.content:
- messages.append(
- Message.from_role_and_content(item.role, content_item.text)
- )
- # add final channel to the last assistant message if it's from the assistant
- if item.role == Role.ASSISTANT:
+ for content_item in content:
+ messages.append(Message.from_role_and_content(role, content_item.text))
+ if role == Role.ASSISTANT:
messages[-1] = messages[-1].with_channel("final")
- elif item.type == "reasoning":
- # Only include reasoning if it is after the last assistant message and we are handling a function call at the moment
- if (
- idx > last_assistant_idx
- and is_last_message_function_call_output
- ):
+
+ elif itype == "reasoning":
+ if idx > last_assistant_idx and is_last_message_function_call_output:
for content_item in item.content:
messages.append(
- Message.from_role_and_content(
- Role.ASSISTANT, content_item.text
- ).with_channel("analysis")
+ Message.from_role_and_content(Role.ASSISTANT, content_item.text).with_channel("analysis")
)
- elif item.type == "function_call":
- function_call_map[item.call_id] = item
+
+ elif itype == "function_call":
+ function_call_map[item.call_id] = item # type: ignore[index]
messages.append(
- Message.from_role_and_content(Role.ASSISTANT, item.arguments)
- .with_recipient(f"functions.{item.name}")
+ Message.from_role_and_content(Role.ASSISTANT, item.arguments) # type: ignore[arg-type]
+ .with_recipient(f"functions.{item.name}") # type: ignore[attr-defined]
.with_channel("commentary")
)
- elif item.type == "function_call_output":
- function_call = function_call_map.get(item.call_id, None)
+
+ elif itype == "function_call_output":
+ function_call = function_call_map.get(item.call_id) # type: ignore[index]
if not function_call:
raise ValueError(f"Function call {item.call_id} not found")
messages.append(
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{function_call.name}"),
- item.output,
- ).with_recipient("assistant").with_channel("commentary")
+ item.output, # type: ignore[arg-type]
+ )
+ .with_recipient("assistant")
+ .with_channel("commentary")
)
conversation = Conversation.from_messages(messages)
-
- initial_tokens = encoding.render_conversation_for_completion(
- conversation, Role.ASSISTANT
- )
+ initial_tokens = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)
print(encoding.decode_utf8(initial_tokens))
response_id = f"resp_{uuid.uuid4().hex}"
- def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):
+ def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject) -> None:
responses_store[rid] = (req, resp)
event_stream = StreamResponsesEvents(
initial_tokens,
body,
- as_sse=body.stream,
+ as_sse=bool(body.stream),
request=request,
response_id=response_id,
store_callback=store_callback,
@@ -915,7 +870,7 @@ def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):
last_event = None
async for event in event_stream.run():
last_event = event
-
+ # last_event is a ResponseCompletedEvent
return last_event.response
return app
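
As a rough client-side sketch of the `/v1/responses` endpoint wired up above (non-streaming case); the host, port, and request values are assumptions rather than values taken from this diff.

```python
# Hypothetical client call; host/port and field values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/responses",
    json={
        "input": "why did the chicken cross the road?",
        "instructions": "Answer briefly.",
        "stream": False,
    },
)
resp.raise_for_status()
print(resp.json()["output"])
```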
diff --git a/gpt_oss/responses_api/events.py b/gpt_oss/responses_api/events.py
index 7adecc6..be63cf9 100644
--- a/gpt_oss/responses_api/events.py
+++ b/gpt_oss/responses_api/events.py
@@ -3,16 +3,9 @@
from pydantic import BaseModel
-from .types import (
- FunctionCallItem,
- Item,
- ReasoningItem,
- ResponseObject,
- TextContentItem,
- ReasoningTextContentItem,
- WebSearchCallItem,
- UrlCitation,
-)
+from .types import (FunctionCallItem, Item, ReasoningItem,
+ ReasoningTextContentItem, ResponseObject, TextContentItem,
+ UrlCitation, WebSearchCallItem)
class ResponseEvent(BaseModel):
@@ -105,25 +98,37 @@ class ResponseContentPartDone(ResponseEvent):
content_index: int = 0
part: Union[TextContentItem, ReasoningTextContentItem]
+
class ResponseOutputTextAnnotationAdded(ResponseEvent):
- type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
+ type: Literal["response.output_text.annotation.added"] = (
+ "response.output_text.annotation.added"
+ )
item_id: str = "item_1234"
output_index: int = 0
content_index: int = 0
annotation_index: int = 0
annotation: UrlCitation
+
class ResponseWebSearchCallInProgress(ResponseEvent):
- type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+ type: Literal["response.web_search_call.in_progress"] = (
+ "response.web_search_call.in_progress"
+ )
output_index: int = 0
item_id: str = "item_1234"
+
class ResponseWebSearchCallSearching(ResponseEvent):
- type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+ type: Literal["response.web_search_call.searching"] = (
+ "response.web_search_call.searching"
+ )
output_index: int = 0
item_id: str = "item_1234"
+
class ResponseWebSearchCallCompleted(ResponseEvent):
- type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+ type: Literal["response.web_search_call.completed"] = (
+ "response.web_search_call.completed"
+ )
output_index: int = 0
- item_id: str = "item_1234"
\ No newline at end of file
+ item_id: str = "item_1234"
diff --git a/gpt_oss/responses_api/inference/ollama.py b/gpt_oss/responses_api/inference/ollama.py
index 35eb1b2..43a87c7 100644
--- a/gpt_oss/responses_api/inference/ollama.py
+++ b/gpt_oss/responses_api/inference/ollama.py
@@ -8,17 +8,17 @@
import threading
import time
from typing import Callable, Optional
-import requests
-from openai_harmony import load_harmony_encoding, HarmonyEncodingName
+import requests
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
EOS_TOKEN = 200002 # only used on hard timeout
# Tunables
-POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
-CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
-NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
-FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
+POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
+CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
+NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
+FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
# Shared state
_token_buffer: list[int] = []
@@ -26,9 +26,10 @@
_stream_thread: Optional[threading.Thread] = None
_stream_done = threading.Event()
_stream_error: Optional[Exception] = None
-_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
+_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
_previous_request_tokens: list[int] = []
+
def lcp(cache: list[int], inp: list[int]) -> list[int]:
i = 0
max_len = min(len(cache), len(inp))
@@ -36,13 +37,16 @@ def lcp(cache: list[int], inp: list[int]) -> list[int]:
i += 1
return cache[:i]
+
def _now():
return time.monotonic()
+
def _touch_progress():
global _last_progress_ts
_last_progress_ts = _now()
+
def _reset_stream_state():
global _token_buffer, _stream_thread, _stream_error
with _buffer_lock:
@@ -52,14 +56,15 @@ def _reset_stream_state():
_stream_error = None
_touch_progress()
+
def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
model_name = checkpoint
def _start_stream(token_ids: list[int], temperature: float):
prompt_text = encoding.decode(token_ids)
+
def run():
- nonlocal prompt_text, temperature
global _stream_error
global _previous_request_tokens
@@ -187,6 +192,8 @@ def infer_next_token(
# If we reach here, we still haven't got a token—ask the caller to call again soon.
# Return a harmless token that the server will replace/ignore if your interface supports it.
# If your interface does NOT allow a sentinel, keep the short-blocking behavior above.
- return EOS_TOKEN if False else 0 # replace `0` with a PAD/NOOP token your server ignores
+ return (
+ EOS_TOKEN if False else 0
+ ) # replace `0` with a PAD/NOOP token your server ignores
return infer_next_token
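
A minimal sketch of how a caller might drive the Ollama-backed `infer_next_token` above, treating the `0` return as the "no token yet" sentinel described in the comment; the model tag and prompt are assumptions.

```python
# Sketch under assumptions: "gpt-oss:20b" is a guessed Ollama tag and the prompt is a placeholder.
from openai_harmony import HarmonyEncodingName, load_harmony_encoding

from gpt_oss.responses_api.inference.ollama import EOS_TOKEN, setup_model

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
infer_next_token = setup_model("gpt-oss:20b")

# A real caller would render a full Harmony conversation instead of this placeholder.
tokens = encoding.encode("<|start|>user<|message|>Hello<|end|>", allowed_special="all")

new_request = True
while True:
    tok = infer_next_token(tokens, temperature=0.0, new_request=new_request)
    new_request = False
    if tok == EOS_TOKEN:
        break
    if tok == 0:
        continue  # sentinel "no token yet" per the comment above; poll again
    tokens.append(tok)

print(encoding.decode(tokens))
```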
diff --git a/gpt_oss/responses_api/inference/transformers.py b/gpt_oss/responses_api/inference/transformers.py
index c743707..23f6691 100644
--- a/gpt_oss/responses_api/inference/transformers.py
+++ b/gpt_oss/responses_api/inference/transformers.py
@@ -6,14 +6,14 @@
import os
from typing import Callable, List
+import torch
# Transformers imports
from transformers import AutoModelForCausalLM, PreTrainedModel
-import torch
-
DEFAULT_TEMPERATURE = 0.0
TP = os.environ.get("TP", 2)
+
def load_model(checkpoint: str):
"""
Serve the model directly with the Auto API.
@@ -41,10 +41,15 @@ def get_infer_next_token(model: PreTrainedModel):
def infer_next_token(
tokens: List[int],
temperature: float = DEFAULT_TEMPERATURE,
- new_request: bool = False, # kept for interface compatibility; unused here
+ new_request: bool = False, # kept for interface compatibility; unused here
) -> int:
tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)
- output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature)
+ output = model.generate(
+ tokens,
+ max_new_tokens=1,
+ do_sample=temperature != 0,
+ temperature=temperature,
+ )
return output[0, -1].tolist()
return infer_next_token
diff --git a/gpt_oss/responses_api/inference/vllm.py b/gpt_oss/responses_api/inference/vllm.py
index 9c07c55..b232b79 100644
--- a/gpt_oss/responses_api/inference/vllm.py
+++ b/gpt_oss/responses_api/inference/vllm.py
@@ -1,6 +1,6 @@
"""
-NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
-one token at a time to mimic the behavior of the Triton implementation.
+NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
+one token at a time to mimic the behavior of the Triton implementation.
"""
import os
@@ -13,6 +13,7 @@
DEFAULT_TEMPERATURE = 0.0
TP = os.environ.get("TP", 2)
+
def load_model(checkpoint: str):
"""
Create the vLLM engine. We enable prefix caching so repeated prefixes
@@ -21,9 +22,9 @@ def load_model(checkpoint: str):
llm = LLM(
model=checkpoint,
- tensor_parallel_size=TP, # set >1 if you want TP across GPUs
- enable_prefix_caching=True, # reuse KV for shared prefixes
- disable_log_stats=True, # uncomment to quiet logs
+ tensor_parallel_size=TP, # set >1 if you want TP across GPUs
+ enable_prefix_caching=True, # reuse KV for shared prefixes
+ disable_log_stats=True, # uncomment to quiet logs
)
return llm
@@ -52,8 +53,8 @@ def infer_next_token(
sampling = SamplingParams(
temperature=float(temperature),
- max_tokens=1, # we only want the next token
- n=1, # single continuation
+ max_tokens=1, # we only want the next token
+ n=1, # single continuation
# You can expose/enable more controls here (top_p, top_k, etc.)
)
diff --git a/gpt_oss/responses_api/serve.py b/gpt_oss/responses_api/serve.py
index 35fc3f4..fc5c19e 100644
--- a/gpt_oss/responses_api/serve.py
+++ b/gpt_oss/responses_api/serve.py
@@ -3,10 +3,7 @@
import argparse
import uvicorn
-from openai_harmony import (
- HarmonyEncodingName,
- load_harmony_encoding,
-)
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
from .api_server import create_api_server
diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py
index 1d908e3..35d17a6 100644
--- a/gpt_oss/responses_api/types.py
+++ b/gpt_oss/responses_api/types.py
@@ -8,6 +8,7 @@
REASONING_EFFORT = ReasoningEffort.LOW
DEFAULT_MAX_OUTPUT_TOKENS = 10_000
+
class UrlCitation(BaseModel):
type: Literal["url_citation"]
end_index: int
@@ -15,6 +16,7 @@ class UrlCitation(BaseModel):
url: str
title: str
+
class TextContentItem(BaseModel):
type: Union[Literal["text"], Literal["input_text"], Literal["output_text"]]
text: str
@@ -61,25 +63,30 @@ class FunctionCallOutputItem(BaseModel):
call_id: str = "call_1234"
output: str
+
class WebSearchActionSearch(BaseModel):
type: Literal["search"]
query: Optional[str] = None
+
class WebSearchActionOpenPage(BaseModel):
type: Literal["open_page"]
url: Optional[str] = None
+
class WebSearchActionFind(BaseModel):
type: Literal["find"]
pattern: Optional[str] = None
url: Optional[str] = None
+
class WebSearchCallItem(BaseModel):
type: Literal["web_search_call"]
id: str = "ws_1234"
status: Literal["in_progress", "completed", "incomplete"] = "completed"
action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind]
+
class Error(BaseModel):
code: str
message: str
@@ -115,7 +122,16 @@ class ResponsesRequest(BaseModel):
instructions: Optional[str] = None
max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS
input: Union[
- str, list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
+ str,
+ list[
+ Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ FunctionCallOutputItem,
+ WebSearchCallItem,
+ ]
+ ],
]
model: Optional[str] = MODEL_IDENTIFIER
stream: Optional[bool] = False
@@ -131,7 +147,15 @@ class ResponsesRequest(BaseModel):
class ResponseObject(BaseModel):
- output: list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
+ output: list[
+ Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ FunctionCallOutputItem,
+ WebSearchCallItem,
+ ]
+ ]
created_at: int
usage: Optional[Usage] = None
status: Literal["completed", "failed", "incomplete", "in_progress"] = "in_progress"
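
To show how the web-search models above fit together, a small construction sketch; field values are illustrative and unspecified fields rely on the defaults declared in this file.

```python
# Illustrative construction of the web-search models defined above.
from gpt_oss.responses_api.types import (
    WebSearchActionFind,
    WebSearchActionOpenPage,
    WebSearchActionSearch,
    WebSearchCallItem,
)

search = WebSearchActionSearch(type="search", query="gpt-oss")
open_page = WebSearchActionOpenPage(type="open_page", url="https://example.com")
find = WebSearchActionFind(type="find", pattern="license", url="https://example.com")

# id and status fall back to their declared defaults ("ws_1234", "completed").
call = WebSearchCallItem(type="web_search_call", action=search)
print(call)
```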
diff --git a/gpt_oss/tokenizer.py b/gpt_oss/tokenizer.py
index 866077f..c12ac29 100644
--- a/gpt_oss/tokenizer.py
+++ b/gpt_oss/tokenizer.py
@@ -1,5 +1,6 @@
import tiktoken
+
def get_tokenizer():
o200k_base = tiktoken.get_encoding("o200k_base")
tokenizer = tiktoken.Encoding(
@@ -23,8 +24,7 @@ def get_tokenizer():
"<|reserved_200010|>": 200010,
"<|reserved_200011|>": 200011,
"<|call|>": 200012,
- } | {
- f"<|reserved_{i}|>": i for i in range(200013, 201088)
- },
+ }
+ | {f"<|reserved_{i}|>": i for i in range(200013, 201088)},
)
return tokenizer
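
A quick sketch of round-tripping text through the tokenizer defined above; the sample string is arbitrary.

```python
# Round-trip an arbitrary string, allowing the Harmony special tokens registered above.
from gpt_oss.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids = tokenizer.encode("<|call|>hello world", allowed_special="all")
print(ids)                    # starts with 200012, the <|call|> id mapped above
print(tokenizer.decode(ids))  # "<|call|>hello world"
```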
diff --git a/gpt_oss/tools/apply_patch.py b/gpt_oss/tools/apply_patch.py
index a1ecb11..007199e 100644
--- a/gpt_oss/tools/apply_patch.py
+++ b/gpt_oss/tools/apply_patch.py
@@ -12,14 +12,7 @@
import pathlib
from dataclasses import dataclass, field
from enum import Enum
-from typing import (
- Callable,
- Dict,
- List,
- Optional,
- Tuple,
- Union,
-)
+from typing import Callable, Dict, List, Optional, Tuple, Union
# --------------------------------------------------------------------------- #
@@ -493,11 +486,10 @@ def remove_file(path: str) -> None:
pathlib.Path(path).unlink(missing_ok=True)
-
def apply_patch(
text: str,
open_fn: Callable[[str], str] = open_file,
- write_fn: Callable[[str, str], None] = write_file,
+ write_fn: Callable[[str, str], None] = write_file,
remove_fn: Callable[[str], None] = remove_file,
) -> str:
if not text.startswith("*** Begin Patch"):
diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py
index 7067c1e..9c2137c 100644
--- a/gpt_oss/tools/python_docker/docker_tool.py
+++ b/gpt_oss/tools/python_docker/docker_tool.py
@@ -1,22 +1,15 @@
# Run this before running the tool:
# $ docker image pull python:3.11
+import io
+import tarfile
from typing import Any, AsyncIterator
import docker
-from openai_harmony import (
- Author,
- Content,
- Message,
- Role,
- TextContent,
- ToolNamespaceConfig,
-)
-import io
-import tarfile
+from openai_harmony import (Author, Content, Message, Role, TextContent,
+ ToolNamespaceConfig)
from ..tool import Tool
-
_docker_client = None
@@ -84,9 +77,7 @@ def instruction(self) -> str:
@property
def tool_config(self) -> ToolNamespaceConfig:
return ToolNamespaceConfig(
- name=self.get_tool_name(),
- description=self.instruction,
- tools=[]
+ name=self.get_tool_name(), description=self.instruction, tools=[]
)
def _make_response(
@@ -111,7 +102,7 @@ def make_response(
message = Message(
author=author,
content=[content],
- ).with_recipient('assistant')
+ ).with_recipient("assistant")
if channel:
message = message.with_channel(channel)
diff --git a/gpt_oss/tools/simple_browser/__init__.py b/gpt_oss/tools/simple_browser/__init__.py
index 9043cb1..b4f2628 100644
--- a/gpt_oss/tools/simple_browser/__init__.py
+++ b/gpt_oss/tools/simple_browser/__init__.py
@@ -1,5 +1,5 @@
-from .simple_browser_tool import SimpleBrowserTool
from .backend import ExaBackend
+from .simple_browser_tool import SimpleBrowserTool
__all__ = [
"SimpleBrowserTool",
diff --git a/gpt_oss/tools/simple_browser/backend.py b/gpt_oss/tools/simple_browser/backend.py
index 03bdf56..17ada2e 100644
--- a/gpt_oss/tools/simple_browser/backend.py
+++ b/gpt_oss/tools/simple_browser/backend.py
@@ -11,22 +11,12 @@
import chz
from aiohttp import ClientSession, ClientTimeout
-from tenacity import (
- after_log,
- before_sleep_log,
- retry,
- retry_if_exception_type,
- stop_after_attempt,
- wait_exponential,
-)
-
-from .page_contents import (
- Extract,
- FetchResult,
- PageContents,
- get_domain,
- process_html,
-)
+from tenacity import (after_log, before_sleep_log, retry,
+ retry_if_exception_type, stop_after_attempt,
+ wait_exponential)
+
+from .page_contents import (Extract, FetchResult, PageContents, get_domain,
+ process_html)
logger = logging.getLogger(__name__)
@@ -108,11 +98,11 @@ def _get_api_key(self) -> str:
async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:
headers = {"x-api-key": self._get_api_key()}
- async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp:
+ async with session.post(
+ f"{self.BASE_URL}{endpoint}", json=payload, headers=headers
+ ) as resp:
if resp.status != 200:
- raise BackendError(
- f"Exa API error {resp.status}: {await resp.text()}"
- )
+ raise BackendError(f"Exa API error {resp.status}: {await resp.text()}")
return await resp.json()
async def search(
@@ -121,7 +111,11 @@ async def search(
data = await self._post(
session,
"/search",
- {"query": query, "numResults": topn, "contents": {"text": True, "summary": True}},
+ {
+ "query": query,
+ "numResults": topn,
+ "contents": {"text": True, "summary": True},
+ },
)
# make a simple HTML page to work with browser format
titles_and_urls = [
@@ -152,7 +146,7 @@ async def fetch(self, url: str, session: ClientSession) -> PageContents:
data = await self._post(
session,
"/contents",
- {"urls": [url], "text": { "includeHtmlTags": True }},
+ {"urls": [url], "text": {"includeHtmlTags": True}},
)
results = data.get("results", [])
if not results:
diff --git a/gpt_oss/tools/simple_browser/page_contents.py b/gpt_oss/tools/simple_browser/page_contents.py
index 6fffd3f..ab81454 100644
--- a/gpt_oss/tools/simple_browser/page_contents.py
+++ b/gpt_oss/tools/simple_browser/page_contents.py
@@ -16,7 +16,6 @@
import lxml.etree
import lxml.html
import pydantic
-
import tiktoken
logger = logging.getLogger(__name__)
diff --git a/gpt_oss/tools/simple_browser/simple_browser_tool.py b/gpt_oss/tools/simple_browser/simple_browser_tool.py
index 913ee0b..bf619d9 100644
--- a/gpt_oss/tools/simple_browser/simple_browser_tool.py
+++ b/gpt_oss/tools/simple_browser/simple_browser_tool.py
@@ -12,24 +12,12 @@
import structlog
import tiktoken
from aiohttp import ClientSession
-from openai_harmony import (
- Author,
- Content,
- Message,
- Role,
- TextContent,
- ToolNamespaceConfig
-)
+from openai_harmony import (Author, Content, Message, Role, TextContent,
+ ToolNamespaceConfig)
from ..tool import Tool
-
# from functions import Function, from_python
-from .backend import (
- VIEW_SOURCE_PREFIX,
- Backend,
- BackendError,
- maybe_truncate,
-)
+from .backend import VIEW_SOURCE_PREFIX, Backend, BackendError, maybe_truncate
from .page_contents import Extract, PageContents
logger = structlog.stdlib.get_logger(component=__name__)
@@ -44,7 +32,9 @@
)
LINK_PATTERN = re.compile(r"【\d+†(?P<content>[^†】]+)(?:†[^†】]+)?】")
-CITATION_OUTPUT_PATTERN = re.compile(r"【(?P<cursor>\d+)†(?P<content>[^†】]+)(?:†[^†】]+)?】")
+CITATION_OUTPUT_PATTERN = re.compile(
+    r"【(?P<cursor>\d+)†(?P<content>[^†】]+)(?:†[^†】]+)?】"
+)
CallParams = ParamSpec("CallParams")
@@ -352,12 +342,15 @@ def name(self) -> str:
def tool_config(self) -> ToolNamespaceConfig:
config = ToolNamespaceConfig.browser()
config.name = self.name
- config.description = """Tool for browsing.
+ config.description = (
+ """Tool for browsing.
The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
Cite information from the tool using the following format:
`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
Do not quote more than 10 words directly from the tool output.
-sources=""" + self.backend.source
+sources="""
+ + self.backend.source
+ )
return config
@property
@@ -616,8 +609,9 @@ def make_error_message(error: str) -> Message:
else:
raise ValueError("should not be here")
-
- def normalize_citations(self, old_content: str, hide_partial_citations: bool = False) -> tuple[str, list[dict[str, Any]], bool]:
+ def normalize_citations(
+ self, old_content: str, hide_partial_citations: bool = False
+ ) -> tuple[str, list[dict[str, Any]], bool]:
"""
Returns a tuple of (new_message, annotations, has_partial_citations)
- new_message: Message with citations replaced by ([domain](url))
@@ -625,7 +619,9 @@ def normalize_citations(self, old_content: str, hide_partial_citations: bool = F
- has_partial_citations: whether the text includes an unfinished citation
"""
- has_partial_citations = PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None
+ has_partial_citations = (
+ PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None
+ )
if hide_partial_citations and has_partial_citations:
old_content = PARTIAL_FINAL_LINK_PATTERN.sub("", old_content)
@@ -635,12 +631,14 @@ def normalize_citations(self, old_content: str, hide_partial_citations: bool = F
content = match.group("content")
start_idx = match.start()
end_idx = match.end()
- matches.append({
- "cursor": cursor,
- "content": content,
- "start": start_idx,
- "end": end_idx
- })
+ matches.append(
+ {
+ "cursor": cursor,
+ "content": content,
+ "start": start_idx,
+ "end": end_idx,
+ }
+ )
# Build a mapping from cursor to url
cursor_to_url = {}
@@ -673,13 +671,15 @@ def extract_domain(url):
# The start and end indices in the new content
start_index = len(new_content)
end_index = start_index + len(replacement)
- annotations.append({
- "start_index": start_index,
- "end_index": end_index,
- "title": domain,
- "url": url,
- "type": "url_citation",
- })
+ annotations.append(
+ {
+ "start_index": start_index,
+ "end_index": end_index,
+ "title": domain,
+ "url": url,
+ "type": "url_citation",
+ }
+ )
new_content += replacement
else:
# Keep the original citation format if cursor is missing
@@ -693,4 +693,3 @@ def extract_domain(url):
new_content += old_content[last_idx:]
return new_content, annotations, has_partial_citations
-
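
To illustrate the rewrite that `normalize_citations` performs, a standalone sketch applying the same citation pattern to a string; the cursor-to-URL mapping stands in for the page state the real method consults.

```python
# Standalone illustration of the 【cursor†content】 -> ([domain](url)) rewrite;
# cursor_to_url is a stand-in for the browser tool's page state.
import re
from urllib.parse import urlparse

CITATION_OUTPUT_PATTERN = re.compile(
    r"【(?P<cursor>\d+)†(?P<content>[^†】]+)(?:†[^†】]+)?】"
)
cursor_to_url = {"6": "https://example.com/article"}

def rewrite(text: str) -> str:
    def repl(match: re.Match) -> str:
        url = cursor_to_url.get(match.group("cursor"))
        if url is None:
            return match.group(0)  # keep the original citation if the cursor is unknown
        return f"([{urlparse(url).netloc}]({url}))"
    return CITATION_OUTPUT_PATTERN.sub(repl, text)

print(rewrite("Cats sleep a lot.【6†L9-L11】"))
# Cats sleep a lot.([example.com](https://example.com/article))
```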
diff --git a/gpt_oss/tools/tool.py b/gpt_oss/tools/tool.py
index 210a3d1..374feb9 100644
--- a/gpt_oss/tools/tool.py
+++ b/gpt_oss/tools/tool.py
@@ -1,13 +1,8 @@
from abc import ABC, abstractmethod
-from uuid import UUID, uuid4
from typing import AsyncIterator
+from uuid import UUID, uuid4
-from openai_harmony import (
- Author,
- Role,
- Message,
- TextContent,
-)
+from openai_harmony import Author, Message, Role, TextContent
def _maybe_update_inplace_and_validate_channel(
@@ -63,13 +58,17 @@ async def process(self, message: Message) -> AsyncIterator[Message]:
"""
async for m in self._process(message):
if self.output_channel_should_match_input_channel:
- _maybe_update_inplace_and_validate_channel(input_message=message, tool_message=m)
+ _maybe_update_inplace_and_validate_channel(
+ input_message=message, tool_message=m
+ )
yield m
@abstractmethod
async def _process(self, message: Message) -> AsyncIterator[Message]:
"""Override this method to provide the implementation of the tool."""
- if False: # This is to convince the type checker that this is an async generator.
+ if (
+ False
+ ): # This is to convince the type checker that this is an async generator.
yield # type: ignore[unreachable]
_ = message # Stifle "unused argument" warning.
raise NotImplementedError
@@ -94,7 +93,6 @@ def error_message(
return Message(
id=id if id else uuid4(),
author=Author(role=Role.TOOL, name=self.name),
- content=TextContent(text=error_message), # TODO: Use SystemError instead
+ content=TextContent(text=f"Error: {error_message}"),
channel=channel,
).with_recipient("assistant")
-
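
A minimal sketch of a `Tool` subclass using the `_process` hook shown above; whether `name` is the only other member a subclass must provide is an assumption, and the message construction mirrors patterns used elsewhere in this diff.

```python
# Hypothetical subclass; assumes overriding _process (and providing name) is sufficient.
from typing import AsyncIterator

from openai_harmony import Author, Message, Role

from gpt_oss.tools.tool import Tool


class EchoTool(Tool):
    @property
    def name(self) -> str:  # assumed to satisfy the base class's self.name usage
        return "echo"

    async def _process(self, message: Message) -> AsyncIterator[Message]:
        text = message.content[0].text
        yield (
            Message.from_author_and_content(Author.new(Role.TOOL, self.name), text)
            .with_recipient("assistant")
        )
```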
diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index 9180d49..3c473e0 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -420,8 +420,10 @@ def from_checkpoint(
if "mlp1" in name: # both weight and bias
loaded_tensor = loaded_tensor[
:,
- my_rank * 2
- * per_rank_intermediate_size : (my_rank + 1) * 2
+ my_rank
+ * 2
+ * per_rank_intermediate_size : (my_rank + 1)
+ * 2
* per_rank_intermediate_size,
...,
]
@@ -448,16 +450,20 @@ def __init__(self, checkpoint: str, device: torch.device):
self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
@torch.inference_mode()
- def generate(self,
- prompt_tokens: list[int],
- stop_tokens: list[int],
- temperature: float = 1.0,
- max_tokens: int = 0,
- return_logprobs: bool = False):
+ def generate(
+ self,
+ prompt_tokens: list[int],
+ stop_tokens: list[int],
+ temperature: float = 1.0,
+ max_tokens: int = 0,
+ return_logprobs: bool = False,
+ ):
tokens = list(prompt_tokens)
num_generated_tokens = 0
while max_tokens == 0 or num_generated_tokens < max_tokens:
- logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1]
+ logits = self.model(
+ torch.as_tensor(tokens, dtype=torch.int32, device=self.device)
+ )[-1]
if temperature == 0.0:
predicted_token = torch.argmax(logits, dim=-1).item()
else:
diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85..10ce783 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -1,4 +1,5 @@
import os
+
import torch
import torch.distributed as dist
@@ -6,10 +7,11 @@
def suppress_output(rank):
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
+
builtin_print = __builtin__.print
def print(*args, **kwargs):
- force = kwargs.pop('force', False)
+ force = kwargs.pop("force", False)
if force:
builtin_print("rank #%d:" % rank, *args, **kwargs)
elif rank == 0:
diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py
index aa5df58..1fd6867 100644
--- a/gpt_oss/torch/weights.py
+++ b/gpt_oss/torch/weights.py
@@ -4,25 +4,47 @@
import torch
from safetensors import safe_open
-
# Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes
BYTES_PER_BLOCK = 16
FP4_VALUES = [
- +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
- -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
+ +0.0,
+ +0.5,
+ +1.0,
+ +1.5,
+ +2.0,
+ +3.0,
+ +4.0,
+ +6.0,
+ -0.0,
+ -0.5,
+ -1.0,
+ -1.5,
+ -2.0,
+ -3.0,
+ -4.0,
+ -6.0,
]
# Map the names assumed in this implementation to the checkpoint names.
-PARAM_NAME_MAP = {
- f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.mlp1_bias" for n in range(36)
-} | {
- f"block.{n}.mlp.mlp1_weight": (f"block.{n}.mlp.mlp1_weight.blocks", f"block.{n}.mlp.mlp1_weight.scales") for n in range(36)
-} | {
- f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.mlp2_bias" for n in range(36)
-} | {
- f"block.{n}.mlp.mlp2_weight": (f"block.{n}.mlp.mlp2_weight.blocks", f"block.{n}.mlp.mlp2_weight.scales") for n in range(36)
-}
+PARAM_NAME_MAP = (
+ {f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.mlp1_bias" for n in range(36)}
+ | {
+ f"block.{n}.mlp.mlp1_weight": (
+ f"block.{n}.mlp.mlp1_weight.blocks",
+ f"block.{n}.mlp.mlp1_weight.scales",
+ )
+ for n in range(36)
+ }
+ | {f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.mlp2_bias" for n in range(36)}
+ | {
+ f"block.{n}.mlp.mlp2_weight": (
+ f"block.{n}.mlp.mlp2_weight.blocks",
+ f"block.{n}.mlp.mlp2_weight.scales",
+ )
+ for n in range(36)
+ }
+)
class Checkpoint:
@@ -53,13 +75,17 @@ def get(self, name: str) -> torch.Tensor:
match PARAM_NAME_MAP.get(name, name):
case (blocks_name, scales_name):
# MoE weights: are in block-based MXFP4 format
- return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
+ return self._get_mxfp4_tensor(
+ blocks_name, scales_name, dtype=torch.bfloat16
+ )
case tensor_name:
# MoE biases and other weights
return self._get_tensor(tensor_name)
def _get_tensor(self, name: str) -> str:
- assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
+ assert (
+ name in self.tensor_name_to_file
+ ), f"Tensor {name} not found in checkpoint."
with safe_open(
self.tensor_name_to_file[name], framework="pt", device=self.device_str
) as f:
@@ -73,24 +99,24 @@ def _get_mxfp4_tensor(
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 16384 * 512,
) -> torch.Tensor:
- assert blocks_name in self.tensor_name_to_file, (
- f"Blocks tensor {blocks_name} not found in checkpoint."
- )
- assert scales_name in self.tensor_name_to_file, (
- f"Scales tensor {scales_name} not found in checkpoint."
- )
+ assert (
+ blocks_name in self.tensor_name_to_file
+ ), f"Blocks tensor {blocks_name} not found in checkpoint."
+ assert (
+ scales_name in self.tensor_name_to_file
+ ), f"Scales tensor {scales_name} not found in checkpoint."
blocks = self._get_tensor(blocks_name)
scales = self._get_tensor(scales_name).to(torch.int32) - 127
- assert blocks.shape[:-1] == scales.shape, (
- f"{blocks.shape=} does not match {scales.shape=}"
- )
+ assert (
+ blocks.shape[:-1] == scales.shape
+ ), f"{blocks.shape=} does not match {scales.shape=}"
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
*prefix_shape, G, B = blocks.shape
- rows_total = math.prod(prefix_shape) * G
+ rows_total = math.prod(prefix_shape) * G
blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)
@@ -116,7 +142,9 @@ def _get_mxfp4_tensor(
return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
- def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):
+ def _get_mxfp4_tensor_copy(
+ self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16
+ ):
"short version that uses a lot of memory"
loaded_blocks = self._get_tensor(blocks_name)
@@ -124,7 +152,9 @@ def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torc
loaded_blocks_lo = loaded_blocks & 0x0F
loaded_blocks_hi = loaded_blocks >> 4
loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1)
- loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2)
+ loaded_blocks = loaded_blocks.view(
+ *loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2
+ )
loaded_scales = self._get_tensor(scales_name)
# Upcast to int32 and subtract bias
@@ -132,6 +162,8 @@ def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torc
# Convert MXFP4 numbers into target dtype
fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str)
- loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1))
+ loaded_tensor = torch.ldexp(
+ fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1)
+ )
loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1)
return loaded_tensor
diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py
index e9222f0..8f7b845 100644
--- a/gpt_oss/triton/attention.py
+++ b/gpt_oss/triton/attention.py
@@ -8,7 +8,6 @@
import pytest
import torch
-
import triton
import triton.language as tl
@@ -111,7 +110,10 @@ def _attn_fwd(
q = tl.load(Q_block_ptr)
if BANDWIDTH:
- lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), (start_q + start_m + 1) * BLOCK_M
+ lo, hi = (
+ tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH),
+ (start_q + start_m + 1) * BLOCK_M,
+ )
else:
lo, hi = start_q, (start_q + start_m + 1) * BLOCK_M
@@ -125,7 +127,9 @@ def _attn_fwd(
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
- too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
+ too_old = (start_n + offs_n[None, :]) < (
+ start_q + offs_m[:, None] - BANDWIDTH + 1
+ )
mask = mask | too_old
k = tl.load(K_block_ptr)
@@ -186,7 +190,9 @@ def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):
v = torch.nn.functional.pad(v, (0, 0, 0, n_pad_size))
o = torch.empty_like(q)
- M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32)
+ M = torch.empty(
+ (bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32
+ )
grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1)
_attn_fwd[grid](
q,
@@ -244,7 +250,9 @@ def attention_ref(
sliding_window: int | None = None,
start_q: torch.LongTensor = 0,
):
- batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
+ batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = (
+ query.shape
+ )
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
@@ -272,7 +280,9 @@ def attention_ref(
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
- output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim).bfloat16()
+ output = output.reshape(
+ batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim
+ ).bfloat16()
return output
@@ -285,13 +295,37 @@ def attention_ref(
@pytest.mark.parametrize("sm_scale", [0.125])
@pytest.mark.parametrize("sliding_window", [None, 128])
@pytest.mark.parametrize("start_q", [0, 5])
-def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
+def test_eq(
+ batch_size,
+ num_queries,
+ num_keys,
+ num_key_value_heads,
+ num_key_value_groups,
+ head_dim,
+ sm_scale,
+ sliding_window,
+ start_q,
+):
if num_queries > num_keys:
pytest.skip("too many queries")
- q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()
- k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
- v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
+ q = (
+ torch.randn(
+ batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim
+ )
+ .bfloat16()
+ .cuda()
+ )
+ k = (
+ torch.randn(batch_size, num_keys, num_key_value_heads, head_dim)
+ .bfloat16()
+ .cuda()
+ )
+ v = (
+ torch.randn(batch_size, num_keys, num_key_value_heads, head_dim)
+ .bfloat16()
+ .cuda()
+ )
sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()
start_q = torch.tensor([start_q], dtype=torch.int32).cuda()
@@ -299,4 +333,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu
o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)
- torch.testing.assert_close(o1, o2)
\ No newline at end of file
+ torch.testing.assert_close(o1, o2)
diff --git a/gpt_oss/triton/model.py b/gpt_oss/triton/model.py
index 2e14478..008fbb2 100644
--- a/gpt_oss/triton/model.py
+++ b/gpt_oss/triton/model.py
@@ -8,7 +8,7 @@
from gpt_oss.torch.model import ModelConfig, RMSNorm
from gpt_oss.torch.weights import Checkpoint
from gpt_oss.triton.attention import attention, attention_ref
-from gpt_oss.triton.moe import quantize_mx4, moe
+from gpt_oss.triton.moe import moe, quantize_mx4
class RotaryEmbedding(torch.nn.Module):
@@ -78,7 +78,9 @@ def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
def _compute_cos_sin(self, start: int, num_tokens: int):
concentration, inv_freq = self._compute_concentration_and_inv_freq()
- t = torch.arange(start, start + num_tokens, dtype=torch.float32, device=self.device)
+ t = torch.arange(
+ start, start + num_tokens, dtype=torch.float32, device=self.device
+ )
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * concentration
sin = freqs.sin() * concentration
@@ -119,9 +121,20 @@ def forward(
class Cache:
- def __init__(self, batch_size, n_ctx, n_kv_heads, d_head=64, device: torch.device | None = None):
- self.k = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
- self.v = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
+ def __init__(
+ self,
+ batch_size,
+ n_ctx,
+ n_kv_heads,
+ d_head=64,
+ device: torch.device | None = None,
+ ):
+ self.k = torch.zeros(
+ (batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device
+ )
+ self.v = torch.zeros(
+ (batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device
+ )
self.offset = torch.zeros((1,), dtype=torch.long, device=device)
def reset(self):
@@ -147,7 +160,9 @@ def truncate(self, n_ctx):
def extend(self, k, v):
batch_size, n_ctx, *_rest = k.shape
assert batch_size == self.k.shape[0]
- indices = torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset
+ indices = (
+ torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset
+ )
self.k.index_copy_(1, indices, k)
self.v.index_copy_(1, indices, v)
self.offset.add_(n_ctx)
@@ -206,7 +221,7 @@ def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:
qkv_parts = (
self.num_attention_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
- self.num_key_value_heads * self.head_dim
+ self.num_key_value_heads * self.head_dim,
)
q, k, v = torch.split(qkv, qkv_parts, dim=-1)
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
@@ -283,22 +298,24 @@ def __init__(
self.experts_per_token = config.experts_per_token
self.swiglu_limit = config.swiglu_limit
self.norm = RMSNorm(config.hidden_size, device=device)
- self.gate = torch.nn.ParameterDict({
- "weight": torch.nn.Parameter(
- torch.empty(
- (config.hidden_size, config.num_experts),
- device=device,
- dtype=torch.bfloat16,
- )
- ),
- "bias": torch.nn.Parameter(
- torch.empty(
- (config.num_experts,),
- device=device,
- dtype=torch.bfloat16,
- )
- ),
- })
+ self.gate = torch.nn.ParameterDict(
+ {
+ "weight": torch.nn.Parameter(
+ torch.empty(
+ (config.hidden_size, config.num_experts),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ ),
+ "bias": torch.nn.Parameter(
+ torch.empty(
+ (config.num_experts,),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ ),
+ }
+ )
self.mlp1_weight_tensor, self.mlp1_weight_mx = quantize_mx4(
torch.empty(
(
@@ -310,7 +327,9 @@ def __init__(
dtype=torch.bfloat16,
),
)
- self.mlp1_weight = torch.nn.Parameter(self.mlp1_weight_tensor.storage.data, requires_grad=False)
+ self.mlp1_weight = torch.nn.Parameter(
+ self.mlp1_weight_tensor.storage.data, requires_grad=False
+ )
self.mlp1_bias = torch.nn.Parameter(
torch.empty(
(config.num_experts, config.intermediate_size * 2),
@@ -329,7 +348,9 @@ def __init__(
dtype=torch.bfloat16,
),
)
- self.mlp2_weight = torch.nn.Parameter(self.mlp2_weight_tensor.storage.data, requires_grad=False)
+ self.mlp2_weight = torch.nn.Parameter(
+ self.mlp2_weight_tensor.storage.data, requires_grad=False
+ )
self.mlp2_bias = torch.nn.Parameter(
torch.empty(
(config.num_experts, config.hidden_size),
@@ -347,8 +368,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
t = moe(
t,
self.gate["weight"],
- self.mlp1_weight_tensor, self.mlp1_weight_mx,
- self.mlp2_weight_tensor, self.mlp2_weight_mx,
+ self.mlp1_weight_tensor,
+ self.mlp1_weight_mx,
+ self.mlp2_weight_tensor,
+ self.mlp2_weight_mx,
self.gate["bias"].float(),
self.mlp1_bias.float(),
self.mlp2_bias.float(),
@@ -405,8 +428,10 @@ def __init__(
dtype=torch.bfloat16,
)
- def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.Tensor:
- caches=caches or [None] * len(self.block)
+ def forward(
+ self, x: torch.Tensor, caches: list[Cache] | None = None
+ ) -> torch.Tensor:
+ caches = caches or [None] * len(self.block)
with record_function("embedding"):
x = self.embedding(x)
for block, cache in zip(self.block, caches):
@@ -420,7 +445,9 @@ def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.T
@staticmethod
def from_checkpoint(
- path: str, config: ModelConfig | None = None, device: str | torch.device = "cuda",
+ path: str,
+ config: ModelConfig | None = None,
+ device: str | torch.device = "cuda",
) -> "Transformer":
if not isinstance(device, torch.device):
device = torch.device(device)
@@ -472,7 +499,10 @@ class TokenGenerator:
def __init__(self, checkpoint: str, context: int, device: torch.device):
self.device = device
self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
- self.caches = [Cache(1, context, self.model.config.num_key_value_heads, device=self.device) for _ in range(len(self.model.block))]
+ self.caches = [
+ Cache(1, context, self.model.config.num_key_value_heads, device=self.device)
+ for _ in range(len(self.model.block))
+ ]
self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)
# warmup
self.model(self.input_token[None, :], caches=self.caches)
@@ -482,16 +512,20 @@ def __init__(self, checkpoint: str, context: int, device: torch.device):
self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
@torch.inference_mode()
- def generate(self,
- prompt_tokens: list[int],
- stop_tokens: list[int] | None = None,
- temperature: float = 1.0,
- max_tokens: int = 0,
- return_logprobs: bool = False):
+ def generate(
+ self,
+ prompt_tokens: list[int],
+ stop_tokens: list[int] | None = None,
+ temperature: float = 1.0,
+ max_tokens: int = 0,
+ return_logprobs: bool = False,
+ ):
stop_tokens = stop_tokens or []
for cache in self.caches:
cache.reset()
- prompt_tokens = torch.as_tensor(prompt_tokens, dtype=torch.int32, device=self.device)
+ prompt_tokens = torch.as_tensor(
+ prompt_tokens, dtype=torch.int32, device=self.device
+ )
self.model(prompt_tokens[None, :-1], self.caches)
predicted_token = prompt_tokens[-1]
num_generated_tokens = 0
diff --git a/gpt_oss/triton/moe.py b/gpt_oss/triton/moe.py
index 925dbd5..de32c72 100644
--- a/gpt_oss/triton/moe.py
+++ b/gpt_oss/triton/moe.py
@@ -1,16 +1,16 @@
import torch
-from torch.profiler import record_function
-
import triton_kernels
import triton_kernels.swiglu
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
-from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.matmul_ogs import matmul_ogs
+from torch.profiler import record_function
+from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
+ PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
from triton_kernels.routing import routing
-from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
-from triton_kernels.tensor import wrap_torch_tensor, FP4
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details.layout import (HopperMXScaleLayout,
+ HopperMXValueLayout,
+ StridedLayout)
def quantize_mx4(w):
@@ -31,7 +31,22 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True
return out_glu * (x_linear + 1)
-def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
+def moe(
+ x,
+ wg,
+ w1,
+ w1_mx,
+ w2,
+ w2_mx,
+ bg,
+ b1,
+ b2,
+ experts_per_token=4,
+ num_experts=128,
+ swiglu_limit=7.0,
+ fused_act=True,
+ interleaved=True,
+):
if x.numel() == 0:
return x
@@ -42,19 +57,43 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert
with record_function("wg"):
logits = matmul_ogs(x, wg, bg, precision_config=pcg)
with record_function("routing"):
- rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)
+ rdata, gather_indx, scatter_indx = routing(
+ logits, experts_per_token, simulated_ep=1
+ )
if fused_act:
assert interleaved, "Fused activation requires interleaved weights"
with record_function("w1+swiglu"):
- act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.702, swiglu_limit), 2)
- x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
+ act = FusedActivation(
+ FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+ (1.702, swiglu_limit),
+ 2,
+ )
+ x = matmul_ogs(
+ x,
+ w1,
+ b1,
+ rdata,
+ gather_indx=gather_indx,
+ precision_config=pc1,
+ fused_activation=act,
+ )
else:
with record_function("w1"):
- x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
+ x = matmul_ogs(
+ x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1
+ )
with record_function("swiglu"):
x = swiglu(x, limit=swiglu_limit, interleaved=interleaved)
with record_function("w2"):
- x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2, gammas=rdata.gate_scal)
+ x = matmul_ogs(
+ x,
+ w2,
+ b2,
+ rdata,
+ scatter_indx=scatter_indx,
+ precision_config=pc2,
+ gammas=rdata.gate_scal,
+ )
return x
diff --git a/gpt_oss/vllm/token_generator.py b/gpt_oss/vllm/token_generator.py
index 000f322..8b65a76 100644
--- a/gpt_oss/vllm/token_generator.py
+++ b/gpt_oss/vllm/token_generator.py
@@ -1,4 +1,4 @@
-from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt
+from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
class TokenGenerator:
@@ -10,20 +10,24 @@ def __init__(self, model_path: str, tensor_parallel_size: int = 1):
self.engine = LLMEngine.from_engine_args(args)
self.request_id = 0
- def generate(self,
- prompt_tokens: list[int],
- stop_tokens: list[int] | None = None,
- temperature: float = 1.0,
- max_tokens: int = 0,
- return_logprobs: bool = False):
+ def generate(
+ self,
+ prompt_tokens: list[int],
+ stop_tokens: list[int] | None = None,
+ temperature: float = 1.0,
+ max_tokens: int = 0,
+ return_logprobs: bool = False,
+ ):
if max_tokens == 0:
max_tokens = None
request_id = str(self.request_id)
self.request_id += 1
- sampling_params = SamplingParams(temperature=temperature,
- max_tokens=max_tokens,
- stop_token_ids=stop_tokens,
- logprobs=0 if return_logprobs else None)
+ sampling_params = SamplingParams(
+ temperature=temperature,
+ max_tokens=max_tokens,
+ stop_token_ids=stop_tokens,
+ logprobs=0 if return_logprobs else None,
+ )
prompt = TokensPrompt(prompt_token_ids=prompt_tokens)
self.engine.add_request(request_id, prompt, sampling_params)
last_token_id = []
@@ -32,8 +36,12 @@ def generate(self,
output = step_outputs[0].outputs[0]
token_ids = output.token_ids
logprobs_list = output.logprobs if hasattr(output, "logprobs") else None
- new_token_ids = token_ids[len(last_token_id):]
- new_logprobs = logprobs_list[len(last_token_id):] if logprobs_list is not None else [None] * len(new_token_ids)
+ new_token_ids = token_ids[len(last_token_id) :]
+ new_logprobs = (
+ logprobs_list[len(last_token_id) :]
+ if logprobs_list is not None
+ else [None] * len(new_token_ids)
+ )
for token_id, logprobs in zip(new_token_ids, new_logprobs):
last_token_id.append(token_id)
if return_logprobs:
diff --git a/tests/conftest.py b/tests/conftest.py
index 4c008a3..47f4bfb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,16 +1,15 @@
import os
import sys
+from typing import Any, Generator
+from unittest.mock import MagicMock, Mock
+
import pytest
-from typing import Generator, Any
-from unittest.mock import Mock, MagicMock
from fastapi.testclient import TestClient
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
-from openai_harmony import (
- HarmonyEncodingName,
- load_harmony_encoding,
-)
from gpt_oss.responses_api.api_server import create_api_server
@@ -22,24 +21,25 @@ def harmony_encoding():
@pytest.fixture
def mock_infer_token(harmony_encoding):
fake_tokens = harmony_encoding.encode(
- "<|channel|>final<|message|>Test response<|return|>",
- allowed_special="all"
+ "<|channel|>final<|message|>Test response<|return|>", allowed_special="all"
)
token_queue = fake_tokens.copy()
-
- def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int:
+
+ def _mock_infer(
+ tokens: list[int], temperature: float = 0.0, new_request: bool = False
+ ) -> int:
nonlocal token_queue
if len(token_queue) == 0:
token_queue = fake_tokens.copy()
return token_queue.pop(0)
+
return _mock_infer
@pytest.fixture
def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]:
app = create_api_server(
- infer_next_token=mock_infer_token,
- encoding=harmony_encoding
+ infer_next_token=mock_infer_token, encoding=harmony_encoding
)
with TestClient(app) as client:
yield client
@@ -53,7 +53,7 @@ def sample_request_data():
"stream": False,
"reasoning_effort": "low",
"temperature": 0.7,
- "tools": []
+ "tools": [],
}
@@ -72,23 +72,23 @@ def mock_python_tool():
mock.execute.return_value = {
"output": "print('Hello')",
"error": None,
- "exit_code": 0
+ "exit_code": 0,
}
return mock
@pytest.fixture(autouse=True)
def reset_test_environment():
- test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH']
+ test_env_vars = ["OPENAI_API_KEY", "GPT_OSS_MODEL_PATH"]
original_values = {}
-
+
for var in test_env_vars:
if var in os.environ:
original_values[var] = os.environ[var]
del os.environ[var]
-
+
yield
-
+
for var, value in original_values.items():
os.environ[var] = value
@@ -96,23 +96,23 @@ def reset_test_environment():
@pytest.fixture
def performance_timer():
import time
-
+
class Timer:
def __init__(self):
self.start_time = None
self.end_time = None
-
+
def start(self):
self.start_time = time.time()
-
+
def stop(self):
self.end_time = time.time()
return self.elapsed
-
+
@property
def elapsed(self):
if self.start_time and self.end_time:
return self.end_time - self.start_time
return None
-
- return Timer()
\ No newline at end of file
+
+ return Timer()
diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py
index 7fd354b..fb27eba 100644
--- a/tests/test_api_endpoints.py
+++ b/tests/test_api_endpoints.py
@@ -1,12 +1,13 @@
-import pytest
-import json
import asyncio
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
from fastapi import status
-from unittest.mock import patch, MagicMock, AsyncMock
class TestResponsesEndpoint:
-
+
def test_basic_response_creation(self, api_client, sample_request_data):
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
@@ -14,7 +15,7 @@ def test_basic_response_creation(self, api_client, sample_request_data):
assert "id" in data
assert data["object"] == "response"
assert data["model"] == sample_request_data["model"]
-
+
def test_response_with_high_reasoning(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "high"
response = api_client.post("/v1/responses", json=sample_request_data)
@@ -22,7 +23,7 @@ def test_response_with_high_reasoning(self, api_client, sample_request_data):
data = response.json()
assert "id" in data
assert data["status"] == "completed"
-
+
def test_response_with_medium_reasoning(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "medium"
response = api_client.post("/v1/responses", json=sample_request_data)
@@ -30,27 +31,23 @@ def test_response_with_medium_reasoning(self, api_client, sample_request_data):
data = response.json()
assert "id" in data
assert data["status"] == "completed"
-
+
def test_response_with_invalid_model(self, api_client, sample_request_data):
sample_request_data["model"] = "invalid-model"
response = api_client.post("/v1/responses", json=sample_request_data)
# Should still accept but might handle differently
assert response.status_code == status.HTTP_200_OK
-
+
def test_response_with_empty_input(self, api_client, sample_request_data):
sample_request_data["input"] = ""
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
-
+
def test_response_with_tools(self, api_client, sample_request_data):
- sample_request_data["tools"] = [
- {
- "type": "browser_search"
- }
- ]
+ sample_request_data["tools"] = [{"type": "browser_search"}]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
-
+
def test_response_with_custom_temperature(self, api_client, sample_request_data):
for temp in [0.0, 0.5, 1.0, 1.5, 2.0]:
sample_request_data["temperature"] = temp
@@ -58,10 +55,12 @@ def test_response_with_custom_temperature(self, api_client, sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "usage" in data
-
+
def test_streaming_response(self, api_client, sample_request_data):
sample_request_data["stream"] = True
- with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response:
+ with api_client.stream(
+ "POST", "/v1/responses", json=sample_request_data
+ ) as response:
assert response.status_code == status.HTTP_200_OK
# Verify we get SSE events
for line in response.iter_lines():
@@ -73,63 +72,66 @@ def test_streaming_response(self, api_client, sample_request_data):
class TestResponsesWithSession:
-
+
def test_response_with_session_id(self, api_client, sample_request_data):
session_id = "test-session-123"
sample_request_data["session_id"] = session_id
-
+
# First request
response1 = api_client.post("/v1/responses", json=sample_request_data)
assert response1.status_code == status.HTTP_200_OK
data1 = response1.json()
-
+
# Second request with same session
sample_request_data["input"] = "Follow up question"
response2 = api_client.post("/v1/responses", json=sample_request_data)
assert response2.status_code == status.HTTP_200_OK
data2 = response2.json()
-
+
# Should have different response IDs
assert data1["id"] != data2["id"]
-
+
def test_response_continuation(self, api_client, sample_request_data):
# Create initial response
response1 = api_client.post("/v1/responses", json=sample_request_data)
assert response1.status_code == status.HTTP_200_OK
data1 = response1.json()
response_id = data1["id"]
-
+
# Continue the response
continuation_request = {
"model": sample_request_data["model"],
"response_id": response_id,
- "input": "Continue the previous thought"
+ "input": "Continue the previous thought",
}
response2 = api_client.post("/v1/responses", json=continuation_request)
assert response2.status_code == status.HTTP_200_OK
class TestErrorHandling:
-
+
def test_missing_required_fields(self, api_client):
# Model field has default, so test with empty JSON
response = api_client.post("/v1/responses", json={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
+
def test_invalid_reasoning_effort(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "invalid"
response = api_client.post("/v1/responses", json=sample_request_data)
# May handle gracefully or return error
- assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]
-
+ assert response.status_code in [
+ status.HTTP_200_OK,
+ status.HTTP_422_UNPROCESSABLE_ENTITY,
+ ]
+
def test_malformed_json(self, api_client):
response = api_client.post(
"/v1/responses",
data="not json",
- headers={"Content-Type": "application/json"}
+ headers={"Content-Type": "application/json"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
+
def test_extremely_long_input(self, api_client, sample_request_data):
# Test with very long input
sample_request_data["input"] = "x" * 100000
@@ -138,55 +140,51 @@ def test_extremely_long_input(self, api_client, sample_request_data):
class TestToolIntegration:
-
+
def test_browser_search_tool(self, api_client, sample_request_data):
- sample_request_data["tools"] = [
- {
- "type": "browser_search"
- }
- ]
+ sample_request_data["tools"] = [{"type": "browser_search"}]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
-
+
def test_function_tool_integration(self, api_client, sample_request_data):
sample_request_data["tools"] = [
{
"type": "function",
"name": "test_function",
"parameters": {"type": "object", "properties": {}},
- "description": "Test function"
+ "description": "Test function",
}
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
-
+
def test_multiple_tools(self, api_client, sample_request_data):
sample_request_data["tools"] = [
- {
- "type": "browser_search"
- },
+ {"type": "browser_search"},
{
"type": "function",
"name": "test_function",
"parameters": {"type": "object", "properties": {}},
- "description": "Test function"
- }
+ "description": "Test function",
+ },
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
class TestPerformance:
-
- def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
+
+ def test_response_time_under_threshold(
+ self, api_client, sample_request_data, performance_timer
+ ):
performance_timer.start()
response = api_client.post("/v1/responses", json=sample_request_data)
elapsed = performance_timer.stop()
-
+
assert response.status_code == status.HTTP_200_OK
# Response should be reasonably fast for mock inference
assert elapsed < 5.0 # 5 seconds threshold
-
+
def test_multiple_sequential_requests(self, api_client, sample_request_data):
# Test multiple requests work correctly
for i in range(3):
@@ -197,12 +195,12 @@ def test_multiple_sequential_requests(self, api_client, sample_request_data):
class TestUsageTracking:
-
+
def test_usage_object_structure(self, api_client, sample_request_data):
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
-
+
assert "usage" in data
usage = data["usage"]
assert "input_tokens" in usage
@@ -210,21 +208,21 @@ def test_usage_object_structure(self, api_client, sample_request_data):
assert "total_tokens" in usage
# reasoning_tokens may not always be present
# assert "reasoning_tokens" in usage
-
+
# Basic validation
assert usage["input_tokens"] >= 0
assert usage["output_tokens"] >= 0
assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
-
+
def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
# Short input
response1 = api_client.post("/v1/responses", json=sample_request_data)
usage1 = response1.json()["usage"]
-
+
# Longer input
sample_request_data["input"] = sample_request_data["input"] * 10
response2 = api_client.post("/v1/responses", json=sample_request_data)
usage2 = response2.json()["usage"]
-
+
# Longer input should use more tokens
- assert usage2["input_tokens"] > usage1["input_tokens"]
\ No newline at end of file
+ assert usage2["input_tokens"] > usage1["input_tokens"]
diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py
index c3caec3..33de68f 100644
--- a/tests/test_responses_api.py
+++ b/tests/test_responses_api.py
@@ -2,10 +2,7 @@
import pytest
from fastapi.testclient import TestClient
-from openai_harmony import (
- HarmonyEncodingName,
- load_harmony_encoding,
-)
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
from gpt_oss.responses_api.api_server import create_api_server