diff --git a/README.md b/README.md index 1796d38..f1919c9 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ outputs = pipe( messages, max_new_tokens=256, ) -print(outputs[0]["generated_text"][-1]) +print(outputs[0]["generated_text"][-1]['content']) ``` [Learn more about how to use gpt-oss with Transformers.](https://cookbook.openai.com/articles/gpt-oss/run-transformers) @@ -253,7 +253,7 @@ hf download openai/gpt-oss-20b --include "metal/*" --local-dir gpt-oss-20b/metal To test it you can run: ```shell -python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/model.bin -p "why did the chicken cross the road?" +python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/metal/model.bin -p "why did the chicken cross the road?" ``` ## Harmony format & tools diff --git a/_build/gpt_oss_build_backend/backend.py b/_build/gpt_oss_build_backend/backend.py index 5cd76bd..d0077ef 100644 --- a/_build/gpt_oss_build_backend/backend.py +++ b/_build/gpt_oss_build_backend/backend.py @@ -29,112 +29,550 @@ `build-backend = "gpt_oss_build_backend.backend"` in pyproject.toml. """ import os +import sys from importlib import import_module -from typing import Any, Mapping, Sequence - - -TRUE_VALUES = {"1", "true", "TRUE", "on", "ON", "yes", "YES"} +from typing import Any, Callable, List, Mapping, Optional, Protocol, Sequence, Union + + +# Configuration constants +ENV_VAR_METAL_BUILD = "GPTOSS_BUILD_METAL" +TRUE_VALUES = {"1", "true", "on", "yes", "y", "t"} + +# Build requirements +METAL_BUILD_REQUIREMENTS = [ + "scikit-build-core>=0.10", + "pybind11>=2.12", + "cmake>=3.26", + "ninja", +] + +SETUPTOOLS_BUILD_REQUIREMENTS: List[str] = [] + + +class BuildBackendProtocol(Protocol): + """Protocol defining the interface for build backends.""" + + def build_wheel( + self, + wheel_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + metadata_directory: Optional[str] = None, + ) -> str: + """Build a wheel and return its filename.""" + ... + + def build_sdist( + self, + sdist_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + ) -> str: + """Build a source distribution and return its filename.""" + ... + + def get_requires_for_build_wheel( + self, + config_settings: Optional[Mapping[str, Any]] = None, + ) -> List[str]: + """Get requirements for building a wheel.""" + ... + + def get_requires_for_build_sdist( + self, + config_settings: Optional[Mapping[str, Any]] = None, + ) -> List[str]: + """Get requirements for building a source distribution.""" + ... + + +class BuildError(Exception): + """Custom exception for build-related errors.""" + pass + + +class ConfigurationError(Exception): + """Exception raised for configuration-related errors.""" + pass + + +def _log_info(message: str) -> None: + """ + Simple logging function that prints to stderr with prefix. + + Args: + message: The message to log. + """ + print(f"[gpt-oss-build] {message}", file=sys.stderr) + + +def _log_error(message: str) -> None: + """ + Log error messages to stderr with prefix. + + Args: + message: The error message to log. + """ + print(f"[gpt-oss-build] ERROR: {message}", file=sys.stderr) + + +def _validate_directory(directory: str, operation: str) -> None: + """ + Validate that a directory exists and is writable. + + Args: + directory: Path to the directory to validate. + operation: Description of the operation for error messages. + + Raises: + ConfigurationError: If directory validation fails. 
+ """ + if not directory: + raise ConfigurationError(f"Directory for {operation} cannot be empty") + + directory_path = os.path.abspath(directory) + + # Create directory if it doesn't exist + try: + os.makedirs(directory_path, exist_ok=True) + except OSError as e: + raise ConfigurationError(f"Cannot create directory {directory_path} for {operation}: {e}") + + # Check if directory is writable + if not os.access(directory_path, os.W_OK): + raise ConfigurationError(f"Directory {directory_path} is not writable for {operation}") + + +def _parse_environment_variable(var_name: str, default: str = "") -> str: + """ + Parse and normalize an environment variable value. + + Args: + var_name: Name of the environment variable. + default: Default value if the variable is not set. + + Returns: + Normalized environment variable value. + """ + value = os.environ.get(var_name, default).strip().lower() + _log_info(f"Environment variable {var_name}='{os.environ.get(var_name, default)}' -> normalized: '{value}'") + return value def _use_metal_backend() -> bool: - return str(os.environ.get("GPTOSS_BUILD_METAL", "")).strip() in TRUE_VALUES - - -def _setuptools_backend(): - from setuptools import build_meta as _bm # type: ignore - - return _bm - - -def _scikit_build_backend(): - return import_module("scikit_build_core.build") - - -def _backend(): - return _scikit_build_backend() if _use_metal_backend() else _setuptools_backend() + """ + Check if Metal backend should be used based on GPTOSS_BUILD_METAL environment variable. + + This function performs case-insensitive, whitespace-tolerant parsing of the + environment variable and logs the decision for transparency. + + Returns: + bool: True if Metal backend should be used, False otherwise. + + Raises: + ConfigurationError: If there's an issue with configuration validation. + """ + try: + env_value = _parse_environment_variable(ENV_VAR_METAL_BUILD) + use_metal = env_value in TRUE_VALUES + + if env_value: + backend_type = "Metal (scikit-build-core)" if use_metal else "setuptools (invalid value)" + _log_info(f"Backend selection: {backend_type}") + else: + _log_info("Backend selection: setuptools (default - no environment variable set)") + + return use_metal + except Exception as e: + _log_error(f"Failed to determine backend from environment: {e}") + _log_info("Falling back to setuptools backend") + return False + + +def _setuptools_backend() -> BuildBackendProtocol: + """ + Get the setuptools build backend. + + Returns: + BuildBackendProtocol: The setuptools build backend instance. + + Raises: + BuildError: If setuptools backend cannot be imported. + """ + try: + from setuptools import build_meta as _bm # type: ignore + _log_info("Successfully imported setuptools build backend") + return _bm + except ImportError as e: + _log_error(f"Failed to import setuptools build backend: {e}") + raise BuildError("setuptools is required but not available") from e + + +def _scikit_build_backend() -> BuildBackendProtocol: + """ + Get the scikit-build-core build backend. + + Returns: + BuildBackendProtocol: The scikit-build-core build backend instance. + + Raises: + BuildError: If scikit-build-core backend cannot be imported. 
+ """ + try: + backend = import_module("scikit_build_core.build") + _log_info("Successfully imported scikit-build-core backend") + return backend + except ImportError as e: + _log_error(f"Failed to import scikit-build-core: {e}") + _log_error("Install Metal build dependencies with: pip install 'scikit-build-core>=0.10' cmake ninja") + raise BuildError("scikit-build-core is required for Metal backend but not available") from e + + +def _backend() -> BuildBackendProtocol: + """ + Get the appropriate build backend based on configuration. + + This function selects between setuptools and scikit-build-core backends + based on the environment configuration, with comprehensive error handling. + + Returns: + BuildBackendProtocol: The selected build backend instance. + + Raises: + BuildError: If the selected backend cannot be loaded. + ConfigurationError: If there's a configuration issue. + """ + try: + if _use_metal_backend(): + _log_info("Configuring scikit-build-core backend for Metal extension") + return _scikit_build_backend() + else: + _log_info("Configuring setuptools backend for pure Python wheel") + return _setuptools_backend() + except (BuildError, ConfigurationError): + # Re-raise known exceptions + raise + except Exception as e: + _log_error(f"Unexpected error while selecting build backend: {e}") + raise BuildError(f"Failed to configure build backend: {e}") from e + + +def _safe_getattr(obj: Any, name: str, default: Optional[Callable] = None) -> Optional[Callable]: + """ + Safely get an attribute from an object with logging. + + Args: + obj: The object to get the attribute from. + name: The name of the attribute. + default: Default value if attribute doesn't exist. + + Returns: + The attribute value or default. + """ + attr = getattr(obj, name, default) + if attr is None: + _log_info(f"Backend doesn't implement {name}") + else: + _log_info(f"Backend supports {name}") + return attr # Required PEP 517 hooks def build_wheel( wheel_directory: str, - config_settings: Mapping[str, Any] | None = None, - metadata_directory: str | None = None, + config_settings: Optional[Mapping[str, Any]] = None, + metadata_directory: Optional[str] = None, ) -> str: - return _backend().build_wheel(wheel_directory, config_settings, metadata_directory) + """ + Build a wheel in the specified directory. + + Args: + wheel_directory: Directory where the wheel should be built. + config_settings: Optional configuration settings. + metadata_directory: Optional metadata directory. + + Returns: + str: The filename of the built wheel. + + Raises: + BuildError: If the wheel build fails. + ConfigurationError: If directory validation fails. + """ + try: + _validate_directory(wheel_directory, "wheel building") + _log_info(f"Building wheel in directory: {wheel_directory}") + + backend = _backend() + result = backend.build_wheel(wheel_directory, config_settings, metadata_directory) + + _log_info(f"Successfully built wheel: {result}") + return result + except (BuildError, ConfigurationError): + raise + except Exception as e: + _log_error(f"Wheel build failed: {e}") + raise BuildError(f"Failed to build wheel: {e}") from e def build_sdist( - sdist_directory: str, config_settings: Mapping[str, Any] | None = None + sdist_directory: str, + config_settings: Optional[Mapping[str, Any]] = None ) -> str: - return _backend().build_sdist(sdist_directory, config_settings) + """ + Build a source distribution in the specified directory. + + Args: + sdist_directory: Directory where the sdist should be built. 
+ config_settings: Optional configuration settings. + + Returns: + str: The filename of the built source distribution. + + Raises: + BuildError: If the sdist build fails. + ConfigurationError: If directory validation fails. + """ + try: + _validate_directory(sdist_directory, "source distribution building") + _log_info(f"Building sdist in directory: {sdist_directory}") + + backend = _backend() + result = backend.build_sdist(sdist_directory, config_settings) + + _log_info(f"Successfully built sdist: {result}") + return result + except (BuildError, ConfigurationError): + raise + except Exception as e: + _log_error(f"Sdist build failed: {e}") + raise BuildError(f"Failed to build sdist: {e}") from e def prepare_metadata_for_build_wheel( - metadata_directory: str, config_settings: Mapping[str, Any] | None = None + metadata_directory: str, + config_settings: Optional[Mapping[str, Any]] = None ) -> str: - # Fallback if backend doesn't implement it - be = _backend() - fn = getattr(be, "prepare_metadata_for_build_wheel", None) - if fn is None: - # setuptools exposes it; scikit-build-core may not. Defer to building a wheel for metadata. - return _setuptools_backend().prepare_metadata_for_build_wheel( - metadata_directory, config_settings - ) - return fn(metadata_directory, config_settings) + """ + Prepare metadata for wheel building. + + Args: + metadata_directory: Directory where metadata should be prepared. + config_settings: Optional configuration settings. + + Returns: + str: The name of the metadata directory. + + Raises: + BuildError: If metadata preparation fails. + ConfigurationError: If directory validation fails. + """ + try: + _validate_directory(metadata_directory, "metadata preparation") + _log_info(f"Preparing metadata in directory: {metadata_directory}") + + backend = _backend() + fn = _safe_getattr(backend, "prepare_metadata_for_build_wheel") + + if fn is None: + # Fallback to setuptools if current backend doesn't support it + _log_info("Using setuptools fallback for metadata preparation") + setuptools_backend = _setuptools_backend() + result = setuptools_backend.prepare_metadata_for_build_wheel( + metadata_directory, config_settings + ) + else: + result = fn(metadata_directory, config_settings) + + _log_info(f"Successfully prepared metadata: {result}") + return result + except (BuildError, ConfigurationError): + raise + except Exception as e: + _log_error(f"Metadata preparation failed: {e}") + raise BuildError(f"Failed to prepare metadata: {e}") from e # Optional hooks def build_editable( - editable_directory: str, config_settings: Mapping[str, Any] | None = None, metadata_directory: str | None = None + main + wheel_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + metadata_directory: Optional[str] = None, + ) -> str: - be = _backend() - fn = getattr(be, "build_editable", None) - if fn is None: - # setuptools implements build_editable; if not available, raise the standard error - raise RuntimeError("Editable installs not supported by the selected backend") - return fn(editable_directory, config_settings) + """ + Build an editable wheel in the specified directory. + + Args: + wheel_directory: Directory where the editable wheel should be built. + config_settings: Optional configuration settings. + metadata_directory: Optional metadata directory. + + Returns: + str: The filename of the built editable wheel. + + Raises: + BuildError: If the editable build fails or is not supported. + ConfigurationError: If directory validation fails. 
+ """ + try: + _validate_directory(wheel_directory, "editable wheel building") + _log_info(f"Building editable install in directory: {wheel_directory}") + + backend = _backend() + fn = _safe_getattr(backend, "build_editable") + + if fn is None: + raise BuildError("Editable installs not supported by the selected backend") + + result = fn(wheel_directory, config_settings, metadata_directory) + _log_info(f"Successfully built editable wheel: {result}") + return result + except (BuildError, ConfigurationError): + raise + except Exception as e: + _log_error(f"Editable build failed: {e}") + raise BuildError(f"Failed to build editable wheel: {e}") from e def get_requires_for_build_wheel( - config_settings: Mapping[str, Any] | None = None, + config_settings: Optional[Mapping[str, Any]] = None, ) -> Sequence[str]: - if _use_metal_backend(): - # Add dynamic build requirements only when building the Metal backend - return [ - "scikit-build-core>=0.10", - "pybind11>=2.12", - "cmake>=3.26", - "ninja", - ] - # setuptools usually returns [] - return list(_setuptools_backend().get_requires_for_build_wheel(config_settings)) + """ + Get the build requirements for building a wheel. + + Args: + config_settings: Optional configuration settings. + + Returns: + Sequence[str]: List of build requirements. + + Raises: + ConfigurationError: If there's a configuration issue. + """ + try: + if _use_metal_backend(): + requirements = METAL_BUILD_REQUIREMENTS.copy() + _log_info(f"Metal backend requires: {', '.join(requirements)}") + return requirements + else: + # Get requirements from setuptools backend + setuptools_backend = _setuptools_backend() + fn = _safe_getattr(setuptools_backend, "get_requires_for_build_wheel") + + if fn is not None: + requirements = list(fn(config_settings)) + else: + requirements = SETUPTOOLS_BUILD_REQUIREMENTS.copy() + + _log_info(f"Setuptools backend requires: {', '.join(requirements) if requirements else 'no additional packages'}") + return requirements + except Exception as e: + _log_error(f"Failed to get build requirements: {e}") + raise ConfigurationError(f"Unable to determine build requirements: {e}") from e def get_requires_for_build_sdist( - config_settings: Mapping[str, Any] | None = None, + config_settings: Optional[Mapping[str, Any]] = None, ) -> Sequence[str]: - # No special requirements for SDist - be = _backend() - fn = getattr(be, "get_requires_for_build_sdist", None) - if fn is None: - return [] - return list(fn(config_settings)) + """ + Get the build requirements for building a source distribution. + + Args: + config_settings: Optional configuration settings. + + Returns: + Sequence[str]: List of build requirements. + + Raises: + ConfigurationError: If there's a configuration issue. + """ + try: + backend = _backend() + fn = _safe_getattr(backend, "get_requires_for_build_sdist") + + if fn is None: + requirements: List[str] = [] + else: + requirements = list(fn(config_settings)) + + _log_info(f"SDist requires: {', '.join(requirements) if requirements else 'no additional packages'}") + return requirements + except Exception as e: + _log_error(f"Failed to get sdist requirements: {e}") + raise ConfigurationError(f"Unable to determine sdist requirements: {e}") from e def get_requires_for_build_editable( - config_settings: Mapping[str, Any] | None = None, + config_settings: Optional[Mapping[str, Any]] = None, +) -> Sequence[str]: + """ + Get the build requirements for building an editable install. + + Args: + config_settings: Optional configuration settings. 
+ + Returns: + Sequence[str]: List of build requirements. + + Raises: + ConfigurationError: If there's a configuration issue. + """ + try: + if _use_metal_backend(): + requirements = METAL_BUILD_REQUIREMENTS.copy() + _log_info(f"Editable Metal backend requires: {', '.join(requirements)}") + return requirements + else: + setuptools_backend = _setuptools_backend() + fn = _safe_getattr(setuptools_backend, "get_requires_for_build_editable") + + if fn is None: + requirements: List[str] = [] + else: + requirements = list(fn(config_settings)) + + _log_info(f"Editable setuptools backend requires: {', '.join(requirements) if requirements else 'no additional packages'}") + return requirements + except Exception as e: + _log_error(f"Failed to get editable build requirements: {e}") + raise ConfigurationError(f"Unable to determine editable build requirements: {e}") from e + + +# Future expansion hooks (currently unused but available for extension) + +def get_requires_for_build_meta( + config_settings: Optional[Mapping[str, Any]] = None, ) -> Sequence[str]: - if _use_metal_backend(): - return [ - "scikit-build-core>=0.10", - "pybind11>=2.12", - "cmake>=3.26", - "ninja", - ] - be = _setuptools_backend() - fn = getattr(be, "get_requires_for_build_editable", None) - if fn is None: - return [] - return list(fn(config_settings)) \ No newline at end of file + """ + Get requirements for building metadata (future PEP extension). + + Args: + config_settings: Optional configuration settings. + + Returns: + Sequence[str]: List of requirements (currently empty). + """ + _log_info("get_requires_for_build_meta called (future extension)") + return [] + + +def build_meta( + meta_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, +) -> str: + """ + Build metadata only (future PEP extension). + + Args: + meta_directory: Directory for metadata. + config_settings: Optional configuration settings. + + Returns: + str: Metadata filename. + + Raises: + BuildError: Always, as this is not yet implemented. 
+ """ + _log_info("build_meta called (future extension)") + raise BuildError("build_meta is not yet implemented") \ No newline at end of file diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 5e40079..8a749cf 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -2,9 +2,9 @@ Harmony chat with tools """ -import atexit import argparse import asyncio +import atexit import datetime import os from pathlib import Path @@ -14,14 +14,8 @@ except ImportError: import readline -import torch import termcolor - -from gpt_oss.tools import apply_patch -from gpt_oss.tools.simple_browser import SimpleBrowserTool -from gpt_oss.tools.simple_browser.backend import ExaBackend -from gpt_oss.tools.python_docker.docker_tool import PythonTool - +import torch from openai_harmony import ( Author, Conversation, @@ -38,6 +32,10 @@ load_harmony_encoding, ) +from gpt_oss.tools import apply_patch +from gpt_oss.tools.python_docker.docker_tool import PythonTool +from gpt_oss.tools.simple_browser import SimpleBrowserTool +from gpt_oss.tools.simple_browser.backend import ExaBackend REASONING_EFFORT = { "high": ReasoningEffort.HIGH, @@ -49,7 +47,10 @@ def get_user_input(): rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 if rank == 0: - user_input = input() + try: + user_input = input() + except EOFError: + return "/exit" else: user_input = "" user_input_list = [user_input] @@ -59,35 +60,42 @@ def get_user_input(): def main(args): + # --- Backend & generator init match args.backend: case "triton": - from gpt_oss.triton.model import TokenGenerator as TritonGenerator from gpt_oss.torch.utils import init_distributed + from gpt_oss.triton.model import TokenGenerator as TritonGenerator + device = init_distributed() generator = TritonGenerator(args.checkpoint, args.context, device) case "torch": from gpt_oss.torch.model import TokenGenerator as TorchGenerator from gpt_oss.torch.utils import init_distributed + device = init_distributed() generator = TorchGenerator(args.checkpoint, device) case "vllm": from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator - generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2) + + generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tp_size) case _: raise ValueError(f"Invalid backend: {args.backend}") encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + # --- System message system_message_content = ( SystemContent.new() .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort]) .with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d")) ) + # Tools + browser_tool = None + python_tool = None + if args.browser: - backend = ExaBackend( - source="web", - ) + backend = ExaBackend(source="web") browser_tool = SimpleBrowserTool(backend=backend) system_message_content = system_message_content.with_tools(browser_tool.tool_config) @@ -98,34 +106,35 @@ def main(args): system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content) messages = [system_message] + # --- Developer message (resolved) + developer_message_content = None if args.apply_patch: apply_patch_instructions = Path(apply_patch.__file__).parent / "apply_patch.md" - developer_message = "" - if args.developer_message: - developer_message = args.developer_message + "\n" + developer_message = (args.developer_message + "\n") if args.developer_message else "" developer_message += apply_patch_instructions.read_text() developer_message_content = ( DeveloperContent.new() .with_instructions(developer_message) - 
.with_function_tools([ - ToolDescription.new( - "apply_patch", - "Patch a file", - parameters={ - "type": "string", - "description": "Formatted patch code", - "default": "*** Begin Patch\n*** End Patch\n", - } - ), - ]) + .with_function_tools( + [ + ToolDescription.new( + "apply_patch", + "Patch a file", + parameters={ + "type": "string", + "description": "Formatted patch code", + "default": "*** Begin Patch\n*** End Patch\n", + }, + ), + ] + ) ) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) elif args.developer_message: developer_message_content = DeveloperContent.new().with_instructions(args.developer_message) messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content)) - else: - developer_message_content = None + # --- Raw vs pretty headers if args.raw: conversation = Conversation.from_messages(messages) tokens = encoding.render_conversation(conversation) @@ -135,7 +144,6 @@ def main(args): user_message_start = encoding.decode(empty_user_message_tokens[:-1]) user_message_end = encoding.decode(empty_user_message_tokens[-1:]) else: - # System message print(termcolor.colored("System Message:", "cyan"), flush=True) print(termcolor.colored("Model Identity:", "cyan"), system_message_content.model_identity, flush=True) print(termcolor.colored("Reasoning Effort:", "cyan"), system_message_content.reasoning_effort, flush=True) @@ -148,138 +156,159 @@ def main(args): print(termcolor.colored("Developer Message:", "yellow"), flush=True) print(developer_message_content.instructions, flush=True) - # Print the system message and the user message start + # --- REPL MESSAGE_PADDING = 12 - while True: - last_message = messages[-1] - if last_message.recipient is None: - if args.raw: - print(user_message_start, end="", flush=True) - user_message = get_user_input() - print(user_message_end, flush=True, end="") + try: + while True: + last_message = messages[-1] + if last_message.recipient is None: + # Get user input + if args.raw: + print(user_message_start, end="", flush=True) + user_text = get_user_input() + print(user_message_end, flush=True, end="") + else: + print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True) + user_text = get_user_input() + + if user_text.strip().lower() in {"/exit", "exit", ":q", "quit"}: + print(termcolor.colored("Bye!", "cyan")) + break + + user_message = Message.from_role_and_content(Role.USER, user_text) + messages.append(user_message) else: - print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True) - user_message = get_user_input() - user_message = Message.from_role_and_content(Role.USER, user_message) - messages.append(user_message) - else: - # Tool or function call - if last_message.recipient.startswith("browser."): - assert args.browser, "Browser tool is not enabled" - tool_name = "Search" - async def run_tool(): - results = [] - async for msg in browser_tool.process(last_message): - results.append(msg) - return results - - result = asyncio.run(run_tool()) - messages += result - elif last_message.recipient.startswith("python"): - assert args.python, "Python tool is not enabled" - tool_name = "Python" - async def run_tool(): - results = [] - async for msg in python_tool.process(last_message): - results.append(msg) - return results - - result = asyncio.run(run_tool()) - messages += result - elif last_message.recipient == "functions.apply_patch": - assert args.apply_patch, "Apply patch tool is not enabled" - tool_name = "Apply Patch" - text = 
last_message.content[0].text - tool_output = None - - if text.startswith("{"): - # this is json, try to extract the patch from it - import json - try: - some_dict = json.loads(text) - _, text = some_dict.popitem() - except Exception as e: - tool_output = f"Error parsing JSON: {e}" - - if tool_output is None: - try: - tool_output = apply_patch.apply_patch(text) - except Exception as e: - tool_output = f"Error applying patch: {e}" - - message = ( - Message( - author=Author.new(Role.TOOL, last_message.recipient), - content=[TextContent(text=tool_output)] + # Tool/function dispatch + if last_message.recipient.startswith("browser."): + assert args.browser, "Browser tool is not enabled" + tool_name = "Search" + + async def run_tool(): + results = [] + async for msg in browser_tool.process(last_message): + results.append(msg) + return results + + result = asyncio.run(run_tool()) + messages += result + elif last_message.recipient.startswith("python"): + assert args.python, "Python tool is not enabled" + tool_name = "Python" + + async def run_tool(): + results = [] + async for msg in python_tool.process(last_message): + results.append(msg) + return results + + result = asyncio.run(run_tool()) + messages += result + elif last_message.recipient == "functions.apply_patch": + assert args.apply_patch, "Apply patch tool is not enabled" + tool_name = "Apply Patch" + text = last_message.content[0].text + tool_output = None + + if text.startswith("{"): + import json + try: + some_dict = json.loads(text) + _, text = some_dict.popitem() + except Exception as e: + tool_output = f"Error parsing JSON: {e}" + + if tool_output is None: + try: + tool_output = apply_patch.apply_patch(text) + except Exception as e: + tool_output = f"Error applying patch: {e}" + + message = ( + Message( + author=Author.new(Role.TOOL, last_message.recipient), + content=[TextContent(text=tool_output)], + ) + .with_recipient("assistant") ) - .with_recipient("assistant") - ) - if last_message.channel: - message = message.with_channel(last_message.channel) + if last_message.channel: + message = message.with_channel(last_message.channel) - result = [message] - messages += result - else: - raise ValueError(f"Unknown tool or function call: {last_message.recipient}") - # Print the tool or function call result - if args.raw: - rendered_result = encoding.render_conversation(Conversation.from_messages(result)) - print(encoding.decode(rendered_result), flush=True, end="") - else: - print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True) - if tool_name == "Search" and not args.show_browser_results: - print("[Search results fed to the model]") + result = [message] + messages += result else: - print(result[0].content[0].text) + raise ValueError(f"Unknown tool or function call: {last_message.recipient}") - conversation = Conversation.from_messages(messages) - tokens = encoding.render_conversation_for_completion( - conversation, Role.ASSISTANT - ) - - if args.raw: - # Print the last two tokens, which are the start of the assistant message - print(encoding.decode(tokens[-2:]), flush=True, end="") + # Print tool output + if args.raw: + rendered_result = encoding.render_conversation(Conversation.from_messages(result)) + print(encoding.decode(rendered_result), flush=True, end="") + else: + print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True) + if tool_name == "Search" and not args.show_browser_results: + print("[Search results fed to the model]") + else: + # Some tools may 
return multiple messages; show first non-empty safely + first_text = "" + for m in result: + if m.content and m.content[0].text: + first_text = m.content[0].text + break + print(first_text) + + # --- Model completion stream + conversation = Conversation.from_messages(messages) + tokens = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT) - parser = StreamableParser(encoding, role=Role.ASSISTANT) - field_created = False - current_output_text = "" - output_text_delta_buffer = "" - for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()): - parser.process(predicted_token) if args.raw: - print(encoding.decode([predicted_token]), end="", flush=True) - continue - - if parser.state == StreamState.EXPECT_START: - print("") # new line - field_created = False - - if not parser.last_content_delta: - continue - - if not field_created: - field_created = True - if parser.current_channel == "final": - print(termcolor.colored("Assistant:", "green"), flush=True) - elif parser.current_recipient is not None: - print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True) - else: - print(termcolor.colored("CoT:", "yellow"), flush=True) + # Print the last two tokens (start of assistant msg) + print(encoding.decode(tokens[-2:]), flush=True, end="") + + parser = StreamableParser(encoding, role=Role.ASSISTANT) + field_created = False + current_output_text = "" + output_text_delta_buffer = "" + + for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()): + parser.process(predicted_token) + if args.raw: + print(encoding.decode([predicted_token]), end="", flush=True) + continue + + if parser.state == StreamState.EXPECT_START: + print("") # new line + field_created = False + + if not parser.last_content_delta: + continue + + if not field_created: + field_created = True + if parser.current_channel == "final": + print(termcolor.colored("Assistant:", "green"), flush=True) + elif parser.current_recipient is not None: + print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True) + else: + print(termcolor.colored("CoT:", "yellow"), flush=True) + + should_send_output_text_delta = True + output_text_delta_buffer += parser.last_content_delta + + if args.browser and browser_tool is not None: + updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations( + current_output_text + output_text_delta_buffer + ) + output_text_delta_buffer = updated_output_text[len(current_output_text) :] + if has_partial_citations: + should_send_output_text_delta = False - should_send_output_text_delta = True - output_text_delta_buffer += parser.last_content_delta - if args.browser: - updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer) - output_text_delta_buffer = updated_output_text[len(current_output_text):] - if has_partial_citations: - should_send_output_text_delta = False - if should_send_output_text_delta: - print(output_text_delta_buffer, end="", flush=True) - current_output_text += output_text_delta_buffer - output_text_delta_buffer = "" + if should_send_output_text_delta: + print(output_text_delta_buffer, end="", flush=True) + current_output_text += output_text_delta_buffer + output_text_delta_buffer = "" - messages += parser.messages + messages += parser.messages + except KeyboardInterrupt: + print("\n" + termcolor.colored("Interrupted. 
Bye!", "cyan")) if __name__ == "__main__": @@ -354,6 +383,12 @@ async def run_tool(): choices=["triton", "torch", "vllm"], help="Inference backend", ) + parser.add_argument( + "--tp-size", + type=int, + default=2, + help="Tensor parallel size for vLLM backend", + ) args = parser.parse_args() if int(os.environ.get("WORLD_SIZE", 1)) == 1: diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py index bb34e2c..b89963b 100644 --- a/gpt_oss/evals/__main__.py +++ b/gpt_oss/evals/__main__.py @@ -3,14 +3,12 @@ from datetime import datetime from . import report +from .aime_eval import AIME25Eval from .basic_eval import BasicEval +from .chat_completions_sampler import (OPENAI_SYSTEM_MESSAGE_API, + ChatCompletionsSampler) from .gpqa_eval import GPQAEval -from .aime_eval import AIME25Eval from .healthbench_eval import HealthBenchEval -from .chat_completions_sampler import ( - OPENAI_SYSTEM_MESSAGE_API, - ChatCompletionsSampler, -) from .responses_sampler import ResponsesSampler @@ -62,16 +60,16 @@ def main(): default=1584, help="Number of threads to run.", ) - parser.add_argument( - "--debug", action="store_true", help="Run in debug mode" - ) + parser.add_argument("--debug", action="store_true", help="Run in debug mode") parser.add_argument( "--examples", type=int, help="Number of examples to use (overrides default)" ) args = parser.parse_args() - sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler + sampler_cls = ( + ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler + ) models = {} for model_name in args.model.split(","): @@ -192,7 +190,8 @@ def get_evals(eval_name, debug_mode): merge_metrics = [] for eval_model_name, result_filename in mergekey2resultpath.items(): try: - result = json.load(open(result_filename, "r+")) + with open(result_filename, "r") as f: + result = json.load(f) except Exception as e: print(e, result_filename) continue diff --git a/gpt_oss/evals/abcd_grader.py b/gpt_oss/evals/abcd_grader.py index 37088c8..cd2351f 100644 --- a/gpt_oss/evals/abcd_grader.py +++ b/gpt_oss/evals/abcd_grader.py @@ -1,23 +1,22 @@ import re import sys - _PATTERNS = [ # 0)"**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter. re.compile( - r'''(?ix) # case‐insensitive, ignore‐space + r"""(?ix) # case‐insensitive, ignore‐space (?:\*{1,2}|_{1,2}) # leading *…* or _…_ Answer[s]? # Answer or Answers \s*[:\-–]? # optional separator (?:\*{1,2}|_{1,2}) # closing wrapper \s* # optional space ([ABCD])\b # the actual letter - ''', - re.X + """, + re.X, ), - # 0.1) - re.compile(r'''(?ix) # ignore case, allow verbose mode + re.compile( + r"""(?ix) # ignore case, allow verbose mode ^\s* # optional leading whitespace (?:\*{1,2}|_{1,2})? # optional markdown wrapper Answer:? # the word 'answer' with an optional colon @@ -27,54 +26,56 @@ ([ABCD]) # capture the letter (?:\*{1,2}|_{1,2})? 
# optional markdown wrapper after letter \s* # optional trailing whitespace, end of line - ''', re.MULTILINE), - + """, + re.MULTILINE, + ), # 1) Answer: (C) or Answers: (B) - re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'), - + re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)"), # 2) Answer: C or Answers – D - re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'), - + re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b"), # 3) Option B or Choice: C - re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'), - + re.compile(r"(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b"), # 7) LaTeX \boxed{...A...}, catches both \boxed{A} and # \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc. - re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE), - + re.compile(r"(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}", re.MULTILINE), # 7.5) LaTeX \boxed{\textbf{...C...}} - re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE), - + re.compile( + r"(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE + ), # 7.51) LaTeX \boxed{\text{...C...}} - re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE), - + re.compile( + r"(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE + ), # 4) bare singletons: (A) [B] - re.compile(r'(?x)(? str | None: m = pat.search(text) if m: letter = m.group(1).upper() - if letter in 'ABCD': + if letter in "ABCD": matches.append((prio, m, letter)) - matches.sort(key=lambda triple: ( - triple[0], - len(triple[1].group(0)) - )) + matches.sort(key=lambda triple: (triple[0], len(triple[1].group(0)))) for _, match, letter in matches: return letter - return text.removeprefix('**')[:1] + return text.removeprefix("**")[:1] def main(): if len(sys.argv) > 1: # Process files for fn in sys.argv[1:]: - with open(fn, encoding='utf8') as fp: + with open(fn, encoding="utf8") as fp: text = fp.read() ans = extract_abcd(text) print(f"{fn} ➜ {ans!r}") @@ -118,4 +116,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/gpt_oss/evals/aime_eval.py b/gpt_oss/evals/aime_eval.py index c6e9d64..0c5033b 100644 --- a/gpt_oss/evals/aime_eval.py +++ b/gpt_oss/evals/aime_eval.py @@ -1,64 +1,82 @@ """ AIME 2025: https://huggingface.co/datasets/opencompass/AIME2025 """ + import random import re + import pandas -from . import report +from . import report from .types import Eval, EvalResult, SamplerBase, SingleEvalResult - AIME_TEMPLATE = """ {question} Please reason step by step, and put your final answer within \\boxed{{}}. 
""" + def format_aime_question(row): return AIME_TEMPLATE.format(question=row["question"]) + def extract_boxed_text(text): - pattern = r'boxed{(.*?)}|framebox{(.*?)}' + pattern = r"boxed{(.*?)}|framebox{(.*?)}" matches = re.findall(pattern, text, re.DOTALL) if matches: for match in matches[::-1]: for group in match: if group != "": - return group.split(',')[-1].strip() - pattern = r'\d+' # get the last integer if no pattern found + return group.split(",")[-1].strip() + pattern = r"\d+" # get the last integer if no pattern found matches = re.findall(pattern, text, re.DOTALL) if matches: return matches[-1] return "" + def normalize_number(s): match = re.match(r"\d+", s) # match digits from the start if not match: return None return match.group(0) + class AIME25Eval(Eval): def __init__( self, n_repeats: int = 4, - num_examples: int | None = None, # restrict to a subset of the data for debugging + num_examples: ( + int | None + ) = None, # restrict to a subset of the data for debugging n_threads: int = 1, ): path1 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl" df1 = pandas.read_json(path1, lines=True) path2 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl" df2 = pandas.read_json(path2, lines=True) - examples = [row.to_dict() for _, row in df1.iterrows()] + [row.to_dict() for _, row in df2.iterrows()] - examples = [{ - "question": row["question"], - "answer": normalize_number(row["answer"]) if isinstance(row["answer"], str) else row["answer"], - } for row in examples] + examples = [row.to_dict() for _, row in df1.iterrows()] + [ + row.to_dict() for _, row in df2.iterrows() + ] + examples = [ + { + "question": row["question"], + "answer": ( + normalize_number(row["answer"]) + if isinstance(row["answer"], str) + else row["answer"] + ), + } + for row in examples + ] rng = random.Random(0) if num_examples: assert n_repeats == 1, "n_repeats only supported for num_examples = None" examples = rng.sample(examples, num_examples) examples = examples * n_repeats - examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples] + examples = [ + example | {"permutation": rng.sample(range(4), 4)} for example in examples + ] self.examples = examples self.n_repeats = n_repeats self.n_threads = n_threads @@ -66,16 +84,16 @@ def __init__( def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ - sampler._pack_message( - content=format_aime_question(row), role="user" - ) + sampler._pack_message(content=format_aime_question(row), role="user") ] sampler_response = sampler(prompt_messages) response_text = sampler_response.response_text - actual_queried_prompt_messages = sampler_response.actual_queried_message_list + actual_queried_prompt_messages = ( + sampler_response.actual_queried_message_list + ) extracted_answer = extract_boxed_text(response_text) correct_answer = int(row["answer"]) - try: # All AIME answers are integers, so we convert the extracted answer to an integer + try: # All AIME answers are integers, so we convert the extracted answer to an integer extracted_answer = int(extracted_answer) except (ValueError, TypeError): extracted_answer = None @@ -87,11 +105,17 @@ def fn(row: dict): correct_answer=correct_answer, extracted_answer=extracted_answer, ) - convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] + convo = actual_queried_prompt_messages + [ + dict(content=response_text, role="assistant") + ] return SingleEvalResult( - 
html=html, score=score, convo=convo, metrics={"chars": len(response_text)} + html=html, + score=score, + convo=convo, + metrics={"chars": len(response_text)}, ) - results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads) + results = report.map_with_progress( + fn, self.examples, num_threads=self.n_threads + ) return report.aggregate_results(results) - diff --git a/gpt_oss/evals/basic_eval.py b/gpt_oss/evals/basic_eval.py index 7799530..89537c1 100644 --- a/gpt_oss/evals/basic_eval.py +++ b/gpt_oss/evals/basic_eval.py @@ -1,25 +1,32 @@ """ Basic eval """ -from . import report +from . import report from .types import Eval, EvalResult, SamplerBase, SingleEvalResult + class BasicEval(Eval): - def __init__(self,): - self.examples = [{ - "question": "hi", - "answer": "hi, how can i help?", - }] + def __init__( + self, + ): + self.examples = [ + { + "question": "hi", + "answer": "hi, how can i help?", + } + ] def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): - sampler_response = sampler([ - sampler._pack_message(content=row["question"], role="user") - ]) + sampler_response = sampler( + [sampler._pack_message(content=row["question"], role="user")] + ) response_text = sampler_response.response_text extracted_answer = response_text - actual_queried_prompt_messages = sampler_response.actual_queried_message_list + actual_queried_prompt_messages = ( + sampler_response.actual_queried_message_list + ) score = 1.0 if len(extracted_answer) > 0 else 0.0 html = report.jinja_env.from_string(report.HTML_JINJA).render( prompt_messages=actual_queried_prompt_messages, @@ -28,11 +35,15 @@ def fn(row: dict): correct_answer=row["answer"], extracted_answer=extracted_answer, ) - convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] + convo = actual_queried_prompt_messages + [ + dict(content=response_text, role="assistant") + ] return SingleEvalResult( - html=html, score=score, convo=convo, metrics={"chars": len(response_text)} + html=html, + score=score, + convo=convo, + metrics={"chars": len(response_text)}, ) results = report.map_with_progress(fn, self.examples, num_threads=1) return report.aggregate_results(results) - diff --git a/gpt_oss/evals/chat_completions_sampler.py b/gpt_oss/evals/chat_completions_sampler.py index 29c1a0a..cfd72eb 100644 --- a/gpt_oss/evals/chat_completions_sampler.py +++ b/gpt_oss/evals/chat_completions_sampler.py @@ -2,11 +2,11 @@ from typing import Any import openai +import structlog from openai import OpenAI from .types import MessageList, SamplerBase, SamplerResponse - OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." OPENAI_SYSTEM_MESSAGE_CHATGPT = ( "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." 
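The sampler hunk that follows replaces the old catch-all `except Exception` with structlog-based logging and a capped exponential backoff that retries only transient OpenAI errors, re-raising anything unexpected. Below is a minimal, self-contained sketch of that retry pattern, assuming the same 60-second cap; the helper name `call_with_backoff` is an illustrative assumption, not part of the diff:

```python
import time
from typing import Callable, Tuple, Type, TypeVar

import structlog

T = TypeVar("T")
logger = structlog.get_logger()


def call_with_backoff(
    fn: Callable[[], T],
    retryable: Tuple[Type[BaseException], ...],
    max_backoff: float = 60.0,
) -> T:
    """Retry ``fn`` on transient errors with capped exponential backoff."""
    trial = 0
    while True:
        try:
            return fn()
        except retryable as e:
            # Same cap as the diff: 1s, 2s, 4s, ... up to max_backoff seconds.
            backoff = min(2 ** trial, max_backoff)
            logger.warning(
                "API error, retrying",
                trial=trial,
                backoff_seconds=backoff,
                error=str(e),
                error_type=type(e).__name__,
            )
            time.sleep(backoff)
            trial += 1
```

In the real sampler, `retryable` corresponds to `(openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)`, while `openai.BadRequestError` short-circuits with a placeholder response and any other exception is re-raised.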
@@ -66,7 +66,9 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: choice = response.choices[0] content = choice.message.content if getattr(choice.message, "reasoning", None): - message_list.append(self._pack_message("assistant", choice.message.reasoning)) + message_list.append( + self._pack_message("assistant", choice.message.reasoning) + ) if not content: raise ValueError("OpenAI API returned empty response; retrying") @@ -76,18 +78,35 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: actual_queried_message_list=message_list, ) except openai.BadRequestError as e: - print("Bad Request Error", e) + logger = structlog.get_logger() + logger.error("Bad request error from OpenAI API", error=str(e)) return SamplerResponse( response_text="No response (bad request).", response_metadata={"usage": None}, actual_queried_message_list=message_list, ) - except Exception as e: - exception_backoff = 2 ** trial # exponential back off - print( - f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", - e, + except ( + openai.RateLimitError, + openai.APITimeoutError, + openai.APIConnectionError, + ) as e: + exception_backoff = min( + 2**trial, 60 + ) # exponential backoff with max 60s + logger = structlog.get_logger() + logger.warning( + "API error, retrying", + trial=trial, + backoff_seconds=exception_backoff, + error=str(e), + error_type=type(e).__name__, ) time.sleep(exception_backoff) trial += 1 + except Exception as e: + logger = structlog.get_logger() + logger.error( + "Unexpected error during API call", error=str(e), trial=trial + ) + raise # Re-raise unexpected errors # unknown error shall throw exception diff --git a/gpt_oss/evals/gpqa_eval.py b/gpt_oss/evals/gpqa_eval.py index 1b12a43..42827f5 100644 --- a/gpt_oss/evals/gpqa_eval.py +++ b/gpt_oss/evals/gpqa_eval.py @@ -9,9 +9,8 @@ import pandas from . 
import report -from .types import Eval, EvalResult, SamplerBase, SingleEvalResult from .abcd_grader import extract_abcd - +from .types import Eval, EvalResult, SamplerBase, SingleEvalResult QUERY_TEMPLATE_MULTICHOICE = """ {Question} @@ -34,7 +33,9 @@ def __init__( self, n_repeats: int = 8, variant: str = "diamond", - num_examples: int | None = None, # restrict to a subset of the data for debugging + num_examples: ( + int | None + ) = None, # restrict to a subset of the data for debugging debug: bool = False, n_threads: int = 1, ): @@ -44,15 +45,23 @@ def __init__( rng = random.Random(0) if debug: - examples = [row.to_dict() for _, row in df.iterrows() if "ESPRESSO spectrograph, please" in row["Question"]] + examples = [ + row.to_dict() + for _, row in df.iterrows() + if "ESPRESSO spectrograph, please" in row["Question"] + ] else: examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: - assert n_repeats == 1, "n_repeats only supported for num_examples = None" + assert ( + n_repeats == 1 + ), "n_repeats only supported for num_examples = None" examples = rng.sample(examples, num_examples) examples = examples * n_repeats - examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples] + examples = [ + example | {"permutation": rng.sample(range(4), 4)} for example in examples + ] self.examples = examples self.n_repeats = n_repeats self.n_threads = n_threads @@ -69,7 +78,11 @@ def fn(row: dict): correct_index = choices.index(row["Correct Answer"]) correct_answer = "ABCD"[correct_index] choices_dict = dict( - A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"] + A=choices[0], + B=choices[1], + C=choices[2], + D=choices[3], + Question=row["Question"], ) prompt_messages = [ sampler._pack_message( @@ -78,7 +91,9 @@ def fn(row: dict): ] sampler_response = sampler(prompt_messages) response_text = sampler_response.response_text - actual_queried_prompt_messages = sampler_response.actual_queried_message_list + actual_queried_prompt_messages = ( + sampler_response.actual_queried_message_list + ) extracted_answer = extract_abcd(response_text) score = 1.0 if extracted_answer == correct_answer else 0.0 html = report.jinja_env.from_string(report.HTML_JINJA).render( @@ -88,12 +103,19 @@ def fn(row: dict): correct_answer=correct_answer, extracted_answer=extracted_answer, ) - convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] + convo = actual_queried_prompt_messages + [ + dict(content=response_text, role="assistant") + ] return SingleEvalResult( - html=html, score=score, convo=convo, metrics={"chars": len(response_text)} + html=html, + score=score, + convo=convo, + metrics={"chars": len(response_text)}, ) - results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads) + results = report.map_with_progress( + fn, self.examples, num_threads=self.n_threads + ) return report.aggregate_results(results) @@ -122,4 +144,4 @@ def fn(row: dict): print("--------------------------------") pass_rate = passes / len(results["convos"]) - print(f"pass@1: {pass_rate}") \ No newline at end of file + print(f"pass@1: {pass_rate}") diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py index 09d184c..1467bad 100644 --- a/gpt_oss/evals/healthbench_eval.py +++ b/gpt_oss/evals/healthbench_eval.py @@ -26,10 +26,8 @@ import numpy as np from . 
import report -from .chat_completions_sampler import ( - OPENAI_SYSTEM_MESSAGE_API, - ChatCompletionsSampler, -) +from .chat_completions_sampler import (OPENAI_SYSTEM_MESSAGE_API, + ChatCompletionsSampler) from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl" @@ -257,14 +255,12 @@ def __init__( subset_name: Literal["hard", "consensus"] | None = None, ): if run_reference_completions: - assert physician_completions_mode is not None, ( - "physician_completions_mode must be provided if run_reference_completions is True" - ) + assert ( + physician_completions_mode is not None + ), "physician_completions_mode must be provided if run_reference_completions is True" assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][ "has_reference" - ], ( - "physician_completions_mode must have reference completions if run_reference_completions is True" - ) + ], "physician_completions_mode must have reference completions if run_reference_completions is True" if subset_name == "hard": input_path = INPUT_PATH_HARD @@ -284,9 +280,9 @@ def __init__( # physician completions mode self.physician_completions_mode = physician_completions_mode if self.physician_completions_mode is not None: - assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, ( - f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}" - ) + assert ( + self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES + ), f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}" # subset to only the rows which have physician completions from that group examples_matching_mode = [ example diff --git a/gpt_oss/evals/report.py b/gpt_oss/evals/report.py index 787dd1f..3d2a028 100644 --- a/gpt_oss/evals/report.py +++ b/gpt_oss/evals/report.py @@ -9,7 +9,6 @@ from .types import EvalResult, Message, SingleEvalResult - HTML_JINJA = """

<h3>Prompt conversation</h3>
{% for message in prompt_messages %} diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py index fd9daef..4d76250 100644 --- a/gpt_oss/evals/responses_sampler.py +++ b/gpt_oss/evals/responses_sampler.py @@ -22,7 +22,7 @@ def __init__( reasoning_effort: str | None = None, base_url: str = "http://localhost:8000/v1", ): - self.client = OpenAI(base_url=base_url, timeout=24*60*60) + self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60) self.model = model self.developer_message = developer_message self.temperature = temperature @@ -63,7 +63,11 @@ def __call__(self, message_list: MessageList) -> SamplerResponse: for output in response.output: if hasattr(output, "text"): - message_list.append(self._pack_message(getattr(output, "role", "assistant"), output.text)) + message_list.append( + self._pack_message( + getattr(output, "role", "assistant"), output.text + ) + ) elif hasattr(output, "content"): for c in output.content: # c.text handled below diff --git a/gpt_oss/evals/types.py b/gpt_oss/evals/types.py index 8f2b42d..3887095 100644 --- a/gpt_oss/evals/types.py +++ b/gpt_oss/evals/types.py @@ -5,16 +5,17 @@ MessageList = list[Message] - @dataclass class SamplerResponse: """ Response from a sampler. """ + response_text: str actual_queried_message_list: MessageList response_metadata: dict[str, Any] + class SamplerBase: """ Base class for defining a sampling model, which can be evaluated, @@ -22,7 +23,7 @@ class SamplerBase: """ def __call__( - self, + self, message_list: MessageList, ) -> SamplerResponse: raise NotImplementedError @@ -63,4 +64,3 @@ class Eval: def __call__(self, sampler: SamplerBase) -> EvalResult: raise NotImplementedError - diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py index c075580..728093e 100644 --- a/gpt_oss/generate.py +++ b/gpt_oss/generate.py @@ -1,6 +1,7 @@ # Model parallel inference # Note: This script is for demonstration purposes only. It is not designed for production use. # See gpt_oss.chat for a more complete example with the Harmony parser. +# Example: # torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" 
model/ import argparse @@ -11,30 +12,47 @@ def main(args): match args.backend: case "torch": - from gpt_oss.torch.utils import init_distributed from gpt_oss.torch.model import TokenGenerator as TorchGenerator + from gpt_oss.torch.utils import init_distributed + device = init_distributed() generator = TorchGenerator(args.checkpoint, device=device) + case "triton": from gpt_oss.torch.utils import init_distributed from gpt_oss.triton.model import TokenGenerator as TritonGenerator + device = init_distributed() - generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device) + generator = TritonGenerator( + args.checkpoint, + context=args.context_length, + device=device + ) + case "vllm": from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator - generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size) + generator = VLLMGenerator( + args.checkpoint, + tensor_parallel_size=args.tensor_parallel_size + ) + case _: raise ValueError(f"Invalid backend: {args.backend}") tokenizer = get_tokenizer() tokens = tokenizer.encode(args.prompt) max_tokens = None if args.limit == 0 else args.limit - for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True): + + for token, logprob in generator.generate( + tokens, + stop_tokens=[tokenizer.eot_token], + temperature=args.temperature, + max_tokens=max_tokens, + return_logprobs=True, + ): tokens.append(token) token_text = tokenizer.decode([token]) - print( - f"Generated token: {repr(token_text)}, logprob: {logprob}" - ) + print(f"Generated token: {repr(token_text)}, logprob: {logprob}") if __name__ == "__main__": @@ -91,5 +109,4 @@ def main(args): help="Context length for Triton backend", ) args = parser.parse_args() - main(args) diff --git a/gpt_oss/logging_config.py b/gpt_oss/logging_config.py new file mode 100644 index 0000000..42a7a3c --- /dev/null +++ b/gpt_oss/logging_config.py @@ -0,0 +1,90 @@ +""" +Logging configuration for gpt-oss package. +Provides structured logging with appropriate levels and formatting. +""" + +import logging +import sys +from typing import Any, Dict, Optional + +import structlog + +# Configure structlog processors +structlog.configure( + processors=[ + # Add log level to log entries + structlog.stdlib.add_log_level, + # Add a timestamp in ISO format + structlog.processors.TimeStamper(fmt="iso"), + # If the "stack_info" key in the event dict is true, remove it and + # render the current stack trace in the "stack" key. + structlog.processors.StackInfoRenderer(), + # Format the exception only once, even if there are multiple loggers + structlog.processors.format_exc_info, + # Render in JSON for production, or pretty print for development + ( + structlog.dev.ConsoleRenderer() + if sys.stderr.isatty() + else structlog.processors.JSONRenderer() + ), + ], + # Our `wrapper_class` is used for passing metadata when binding + wrapper_class=structlog.stdlib.BoundLogger, + # `logger_factory` is used to create wrapped loggers that are used for OUTPUT. + logger_factory=structlog.stdlib.LoggerFactory(), + # Cache the logger for better performance + cache_logger_on_first_use=True, +) + + +def configure_logging( + level: str = "INFO", format_json: bool = False, component: Optional[str] = None +) -> None: + """ + Configure logging for the application. 
+ + Args: + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + format_json: Whether to format logs as JSON + component: Optional component name to include in logs + """ + logging.basicConfig( + format="%(message)s", + stream=sys.stderr, + level=getattr(logging, level.upper()), + ) + + # Update structlog processors based on configuration + processors = [ + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + ] + + if component: + processors.append(structlog.processors.CallsiteParameterAdder()) + + if format_json: + processors.append(structlog.processors.JSONRenderer()) + else: + processors.append(structlog.dev.ConsoleRenderer()) + + structlog.configure(processors=processors) + + +def get_logger(name: str, **initial_values: Any) -> structlog.stdlib.BoundLogger: + """ + Get a structured logger with optional initial values. + + Args: + name: Logger name (typically __name__) + **initial_values: Key-value pairs to include in all log messages + + Returns: + Structured logger instance + """ + logger = structlog.get_logger(name) + if initial_values: + logger = logger.bind(**initial_values) + return logger diff --git a/gpt_oss/metal/examples/chat.py b/gpt_oss/metal/examples/chat.py index f29cec3..c2d8648 100755 --- a/gpt_oss/metal/examples/chat.py +++ b/gpt_oss/metal/examples/chat.py @@ -2,10 +2,9 @@ import argparse import sys - from datetime import date -from gpt_oss.metal import Context, Model +from gpt_oss.metal import Context, Model DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI. Knowledge cutoff: 2024-06 @@ -16,8 +15,16 @@ # Valid channels: analysis, final. Channel must be included for every message.""" -parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format") +parser = argparse.ArgumentParser( + description="Chat with gpt-oss", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "model", + metavar="PATH", + type=str, + help="Path to gpt-oss model in Metal inference format", +) parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt") parser.add_argument( "--context-length", type=int, default=0, help="The maximum context length" @@ -25,9 +32,7 @@ parser.add_argument( "--temperature", type=float, default=1.0, help="Sampling temperature" ) -parser.add_argument( - "--seed", type=int, default=0, help="Sampling seed" -) +parser.add_argument("--seed", type=int, default=0, help="Sampling seed") GREY = "\33[90m" diff --git a/gpt_oss/metal/examples/generate.py b/gpt_oss/metal/examples/generate.py index 3b78199..9e63848 100644 --- a/gpt_oss/metal/examples/generate.py +++ b/gpt_oss/metal/examples/generate.py @@ -5,12 +5,20 @@ from gpt_oss.metal import Context, Model - -parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint') -parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt') -parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate') -parser.add_argument('--context-length', type=int, default=0, help='The maximum context length') +parser = argparse.ArgumentParser( + description="Chat with gpt-oss", + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "model", metavar="PATH", type=str, help="Path to gpt-oss checkpoint" +) +parser.add_argument("-p", "--prompt", type=str, required=True, help="Prompt") +parser.add_argument( + "-l", "--limit", type=int, default=100, help="Number of tokens to generate" +) +parser.add_argument( + "--context-length", type=int, default=0, help="The maximum context length" +) def main(args): @@ -27,8 +35,8 @@ def main(args): while context.num_tokens - prompt_tokens < options.limit: token = context.sample() context.append(token) - print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True) + print(str(tokenizer.decode(token), encoding="utf-8"), end="", flush=True) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/gpt_oss/metal/scripts/create-local-model.py b/gpt_oss/metal/scripts/create-local-model.py index c0de8bd..12bb5ff 100644 --- a/gpt_oss/metal/scripts/create-local-model.py +++ b/gpt_oss/metal/scripts/create-local-model.py @@ -1,22 +1,37 @@ import argparse -import os -import math -import sys -import json import itertools +import json +import math +import os import struct +import sys from uuid import UUID import tiktoken - import torch +from openai_harmony import HarmonyEncodingName, load_harmony_encoding from safetensors import safe_open from tqdm import tqdm -from openai_harmony import load_harmony_encoding, HarmonyEncodingName -parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validated MXFP4 weights') -parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory') -parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file') +parser = argparse.ArgumentParser( + prog="check-mxfp4-weights.py", description="Validated MXFP4 weights" +) +parser.add_argument( + "-s", + "--src", + metavar="DIR", + type=str, + required=True, + help="Path to the input checkpoint directory", +) +parser.add_argument( + "-d", + "--dst", + metavar="FILE", + type=str, + required=True, + help="Path to the output model file", +) o200k_base = tiktoken.get_encoding("o200k_base") @@ -44,21 +59,36 @@ "<|reversed200011|>": 200011, # unused "<|call|>": 200012, "<|refusal|>": 200013, - } + }, ) -FILE_MAGIC = struct.pack('ccccccccccccI', b'G', b'P', b'T', b'-', b'O', b'S', b'S', b' ', b'v', b'1', b'.', b'0', 0) +FILE_MAGIC = struct.pack( + "ccccccccccccI", + b"G", + b"P", + b"T", + b"-", + b"O", + b"S", + b"S", + b" ", + b"v", + b"1", + b".", + b"0", + 0, +) SPECIAL_TOKEN_UUID = { - '<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes, - '<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes, - '<|end|>': UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes, - '<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes, - '<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes, - '<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes, - '<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes, - '<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes, - '<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes, - '<|end_untrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes, + "<|start|>": UUID("55a77c2f-8a01-4c54-8ac2-313bfc7e208d").bytes, + "<|message|>": UUID("16e40431-f47f-4b22-b59b-8b278fc30a54").bytes, + "<|end|>": UUID("fcac2f6d-4705-4f6b-b228-642accac7238").bytes, + "<|return|>": 
UUID("f799ff69-1992-43c4-a3d8-d831f475dc75").bytes, + "<|refusal|>": UUID("e15ba702-28c4-4292-ab8f-ffa434709128").bytes, + "<|constrain|>": UUID("c0bb14c7-6022-49da-ad08-792d67e8b470").bytes, + "<|channel|>": UUID("fd3dda11-c8ab-4033-876e-d93deb172c93").bytes, + "<|call|>": UUID("1220f796-e388-4de5-b487-fe2eb5fe03c0").bytes, + "<|untrusted|>": UUID("07d7da55-b346-4cff-8b37-7cefacf8a3e8").bytes, + "<|end_untrusted|>": UUID("f265bd9c-c717-469e-a447-920687d65d90").bytes, } INCLUDE_SPECIAL_TOKENS = [ @@ -74,62 +104,65 @@ "<|end_untrusted|>", ] -GPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes -APPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes -TIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes +GPTOSS_MODEL_UUID = UUID("df52dc86-1789-4ed0-a295-66f10508145b").bytes +APPLE_GPU_LAYOUT_UUID = UUID("229177a8-5775-4268-bfd8-d588b351c56d").bytes +TIKTOKEN_TOKENIZER_UUID = UUID("7401aded-2a95-40cb-b782-9ccebaafe72b").bytes UE8_OFFSET = 14 # bias to MXFP4 block scales + def write_file_header(f): f.write(FILE_MAGIC) -def write_tokenizer_header(f, - num_special_tokens: int, - num_text_tokens: int, - regex_size: int, - tokens_size: int): + +def write_tokenizer_header( + f, num_special_tokens: int, num_text_tokens: int, regex_size: int, tokens_size: int +): f.write(TIKTOKEN_TOKENIZER_UUID) - f.write(struct.pack(' 0 - dst.write(struct.pack(' ReasoningE def is_not_builtin_tool(recipient: str) -> bool: - return not recipient.startswith("browser.") and not recipient == "python" and not recipient == "assistant" + return ( + not recipient.startswith("browser.") + and recipient != "python" + and recipient != "assistant" + ) + def create_api_server( - infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding + infer_next_token: Callable[[List[int], float], int], + encoding: HarmonyEncoding, ) -> FastAPI: app = FastAPI() - responses_store: dict[str, tuple[ResponsesRequest, ResponseObject]] = {} + responses_store: Dict[str, Tuple[ResponsesRequest, ResponseObject]] = {} def generate_response( - input_tokens: list[int], - output_tokens: list[int], + input_tokens: List[int], + output_tokens: List[int], request_body: ResponsesRequest, debug_mode: bool = False, - function_call_ids: Optional[list[tuple[str, str]]] = None, + function_call_ids: Optional[List[Tuple[str, str]]] = None, response_id: Optional[str] = None, previous_response_id: Optional[str] = None, browser_tool: Optional[SimpleBrowserTool] = None, - browser_call_ids: Optional[list[str]] = None, + browser_call_ids: Optional[List[str]] = None, ) -> ResponseObject: - output = [] - error = None + output: List[Item] = [] + error: Optional[Error] = None + if len(output_tokens) > 0: - if debug_mode: - try: - entries = encoding.parse_messages_from_completion_tokens( - output_tokens, Role.ASSISTANT - ) - except Exception as e: - print(f"Error parsing tokens: {e}") - error = Error( - code="invalid_function_call", - message=f"{e}", - ) - entries = [] - else: + # Parse assistant messages from completion tokens + try: entries = encoding.parse_messages_from_completion_tokens( output_tokens, Role.ASSISTANT ) + except Exception as e: + # In debug mode, surface a structured error + if debug_mode: + print(f"Error parsing tokens: {e}") + error = Error(code="invalid_function_call", message=str(e)) + entries = [] + else: + raise fc_index = 0 browser_tool_index = 0 for entry in entries: entry_dict = entry.to_dict() - if len(entry_dict.get("recipient", "")) > 0 and 
is_not_builtin_tool(entry_dict["recipient"]): + + # Function tool calls (non-builtin) + if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool( + entry_dict["recipient"] + ): call = entry_dict["content"][0] arguments = call["text"] name = entry_dict["recipient"] - if name.startswith("functions."): name = name[len("functions.") :] + if function_call_ids and fc_index < len(function_call_ids): fc_id, call_id = function_call_ids[fc_index] else: @@ -130,6 +138,7 @@ def generate_response( f"call_{uuid.uuid4().hex}", ) fc_index += 1 + output.append( FunctionCallItem( type="function_call", @@ -139,32 +148,34 @@ def generate_response( call_id=call_id, ) ) - elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None: - # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked + + # Browser tool calls + elif ( + len(entry_dict.get("recipient", "")) > 0 + and entry_dict["recipient"].startswith("browser.") + and browser_tool is not None + ): name = entry_dict["recipient"] call = entry_dict["content"][0] arguments = call["text"] - function_name = name[len("browser."):] + function_name = name[len("browser.") :] - # Reconstruct a Message for argument parsing tool_msg = ( Message.from_role_and_content(Role.ASSISTANT, arguments) .with_recipient(name) .with_channel("analysis") ) - action = None + action: Optional[WebSearchActionSearch | WebSearchActionOpenPage | WebSearchActionFind] = None try: parsed_args = browser_tool.process_arguments(tool_msg) if function_name == "search": action = WebSearchActionSearch( - type="search", - query=parsed_args["query"], + type="search", query=parsed_args["query"] ) elif function_name == "open": action = WebSearchActionOpenPage( - type="open_page", - url=parsed_args["url"], + type="open_page", url=parsed_args["url"] ) elif function_name == "find": action = WebSearchActionFind( @@ -184,16 +195,18 @@ def generate_response( browser_tool_index += 1 output.append( WebSearchCallItem( - type="web_search_call", - id=web_search_call_id, - action=action, + type="web_search_call", id=web_search_call_id, action=action ) ) - elif entry_dict["channel"] == "final": - content = [] - for content_entry in entry_dict["content"]: + + # Final channel message + elif entry_dict.get("channel") == "final": + content: List[TextContentItem] = [] + for content_entry in entry_dict["content"]: if browser_tool: - text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"]) + text_content, annotation_entries, _has_partial = ( + browser_tool.normalize_citations(content_entry["text"]) + ) annotations = [UrlCitation(**a) for a in annotation_entries] else: text_content = content_entry["text"] @@ -201,9 +214,7 @@ def generate_response( content.append( TextContentItem( - type="output_text", - text=text_content, - annotations=annotations, + type="output_text", text=text_content, annotations=annotations ) ) @@ -215,24 +226,18 @@ def generate_response( status="completed", ) ) - elif entry_dict["channel"] == "analysis": - summary = [] + + # Analysis channel (reasoning) + elif entry_dict.get("channel") == "analysis": content = [ ReasoningTextContentItem( - type="reasoning_text", - text=entry["text"], + type="reasoning_text", text=entry["text"] ) for entry in entry_dict["content"] ] output.append( - ReasoningItem( - type="reasoning", - summary=summary, - content=content, - ) + ReasoningItem(type="reasoning", summary=[], content=content) ) - else: - output = 
[] usage = ( Usage( @@ -246,18 +251,15 @@ def generate_response( try: debug_str = encoding.decode_utf8(input_tokens + output_tokens) - except Exception: - debug_str = input_tokens + output_tokens - try: debug_input_str = encoding.decode_utf8(input_tokens) - except Exception: - debug_input_str = input_tokens - try: debug_output_str = encoding.decode_utf8(output_tokens) except Exception: - debug_output_str = output_tokens + # Fallback to raw tokens when decode fails + debug_str = input_tokens + output_tokens # type: ignore[assignment] + debug_input_str = input_tokens # type: ignore[assignment] + debug_output_str = output_tokens # type: ignore[assignment] - metadata = ( + metadata: Dict[str, Any] = ( { "__debug": debug_str, "__debug_input": debug_input_str, @@ -281,18 +283,17 @@ def generate_response( ) class StreamResponsesEvents: - initial_tokens: list[int] - tokens: list[int] - output_tokens: list[int] + initial_tokens: List[int] + tokens: List[int] + output_tokens: List[int] output_text: str request_body: ResponsesRequest request: Request sequence_number: int - def __init__( self, - initial_tokens, + initial_tokens: List[int], request_body: ResponsesRequest, as_sse: bool = False, request: Optional[Request] = None, @@ -301,7 +302,7 @@ def __init__( Callable[[str, ResponsesRequest, ResponseObject], None] ] = None, browser_tool: Optional[SimpleBrowserTool] = None, - ): + ) -> None: self.initial_tokens = initial_tokens self.tokens = initial_tokens.copy() self.output_tokens = [] @@ -309,10 +310,7 @@ def __init__( self.request_body = request_body self.parser = StreamableParser(encoding, role=Role.ASSISTANT) self.as_sse = as_sse - self.debug_mode = request_body.metadata.get( - "__debug", False - ) # we use this for demo purposes - # Set temperature for this stream, fallback to DEFAULT_TEMPERATURE if not set + self.debug_mode = request_body.metadata.get("__debug", False) self.temperature = ( request_body.temperature if request_body.temperature is not None @@ -320,13 +318,13 @@ def __init__( ) self.request = request self.sequence_number = 0 - self.function_call_ids: list[tuple[str, str]] = [] + self.function_call_ids: List[Tuple[str, str]] = [] self.response_id = response_id self.store_callback = store_callback self.new_request = True self.browser_tool = browser_tool self.use_browser_tool = browser_tool is not None - self.browser_call_ids: list[str] = [] + self.browser_call_ids: List[str] = [] def _send_event(self, event: ResponseEvent): event.sequence_number = self.sequence_number @@ -348,46 +346,39 @@ async def run(self): previous_response_id=self.request_body.previous_response_id, ) initial_response.status = "in_progress" + yield self._send_event( - ResponseCreatedEvent( - type="response.created", - response=initial_response, - ) + ResponseCreatedEvent(type="response.created", response=initial_response) ) yield self._send_event( ResponseInProgressEvent( - type="response.in_progress", - response=initial_response, + type="response.in_progress", response=initial_response ) ) - current_content_index = ( - 0 # for this implementation we will always have one content item only - ) + current_content_index = 0 current_output_index = -1 sent_output_item_added = False - # we use this if the model outputs a citation to buffer until completed - output_delta_buffer = "" - # we use this to track the current output text content for things like providing the right indices in citations - current_output_text_content = "" - current_annotations = [] + output_delta_buffer = "" + current_output_text_content = "" + 
current_annotations: List[Dict[str, Any]] = [] while True: - # Check for client disconnect if self.request is not None and await self.request.is_disconnected(): print("Client disconnected, stopping token generation.") break + next_tok = infer_next_token( - self.tokens, - temperature=self.temperature, - new_request=self.new_request, + self.tokens, temperature=self.temperature, new_request=self.new_request ) self.new_request = False self.tokens.append(next_tok) + try: self.parser.process(next_tok) - except Exception as e: + except Exception: + # Parser may raise before enough tokens; safe to continue pass if self.parser.state == StreamState.EXPECT_START: @@ -396,12 +387,10 @@ async def run(self): if len(self.parser.messages) > 0: previous_item = self.parser.messages[-1] + if previous_item.recipient is not None: recipient = previous_item.recipient - if ( - not recipient.startswith("browser.") - and not recipient == "python" - ): + if not recipient.startswith("browser.") and recipient != "python": fc_id = f"fc_{uuid.uuid4().hex}" call_id = f"call_{uuid.uuid4().hex}" self.function_call_ids.append((fc_id, call_id)) @@ -412,13 +401,9 @@ async def run(self): item=FunctionCallItem( type="function_call", name=( - previous_item.recipient[ - len("functions.") : - ] - if previous_item.recipient.startswith( - "functions." - ) - else previous_item.recipient + recipient[len("functions.") :] + if recipient.startswith("functions.") + else recipient ), arguments=previous_item.content[0].text, id=fc_id, @@ -426,6 +411,7 @@ async def run(self): ), ) ) + if previous_item.channel == "analysis": yield self._send_event( ResponseReasoningTextDone( @@ -462,17 +448,19 @@ async def run(self): ), ) ) + if previous_item.channel == "final": annotations = [UrlCitation(**a) for a in current_annotations] if browser_tool: - normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text) + normalized_text, _annos, _has_partial = browser_tool.normalize_citations( + previous_item.content[0].text + ) else: normalized_text = previous_item.content[0].text annotations = [] + text_content = TextContentItem( - type="output_text", - text=normalized_text, - annotations=annotations, + type="output_text", text=normalized_text, annotations=annotations ) yield self._send_event( ResponseOutputTextDone( @@ -494,16 +482,13 @@ async def run(self): ResponseOutputItemDone( type="response.output_item.done", output_index=current_output_index, - item=Item( - type="message", - role="assistant", - content=[text_content], - ), + item=Item(type="message", role="assistant", content=[text_content]), ) ) current_annotations = [] current_output_text_content = "" + # Streaming assistant output (final channel) if ( self.parser.last_content_delta and self.parser.current_channel == "final" @@ -529,16 +514,17 @@ async def run(self): output_delta_buffer += self.parser.last_content_delta should_send_output_text_delta = True + if browser_tool: - # we normalize on the full current text to get the right indices in citations - updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer) - # remove the current text to get back the delta but now normalized - output_delta_buffer = updated_output_text[len(current_output_text_content):] - - # Filter annotations to only include those whose start_index is not already present in current_annotations - # this is to avoid sending duplicate annotations as multiple annotations can't be in the 
same place + updated_output_text, annotations, has_partial = browser_tool.normalize_citations( + current_output_text_content + output_delta_buffer + ) + output_delta_buffer = updated_output_text[len(current_output_text_content) :] + existing_start_indices = {a["start_index"] for a in current_annotations} - new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices] + new_annotations = [ + a for a in annotations if a["start_index"] not in existing_start_indices + ] for a in new_annotations: current_annotations.append(a) citation = UrlCitation(**a) @@ -552,10 +538,9 @@ async def run(self): ) ) - if has_partial_citations: + if has_partial: should_send_output_text_delta = False - if should_send_output_text_delta: yield self._send_event( ResponseOutputTextDelta( @@ -568,6 +553,7 @@ async def run(self): current_output_text_content += output_delta_buffer output_delta_buffer = "" + # Streaming reasoning (analysis channel) if ( self.parser.last_content_delta and self.parser.current_channel == "analysis" @@ -579,9 +565,7 @@ async def run(self): ResponseOutputItemAdded( type="response.output_item.added", output_index=current_output_index, - item=ReasoningItem( - type="reasoning", summary=[], content=[] - ), + item=ReasoningItem(type="reasoning", summary=[], content=[]), ) ) yield self._send_event( @@ -592,6 +576,7 @@ async def run(self): part=ReasoningTextContentItem(type="reasoning_text", text=""), ) ) + yield self._send_event( ResponseReasoningTextDelta( type="response.reasoning_text.delta", @@ -602,59 +587,54 @@ async def run(self): ) try: - # purely for debugging purposes output_token_text = encoding.decode_utf8([next_tok]) self.output_text += output_token_text print(output_token_text, end="", flush=True) - except RuntimeError: pass if next_tok in encoding.stop_tokens_for_assistant_actions(): if len(self.parser.messages) > 0: last_message = self.parser.messages[-1] + # Handle browser tool invocation mid-stream if ( self.use_browser_tool and last_message.recipient is not None and last_message.recipient.startswith("browser.") ): - function_name = last_message.recipient[len("browser."):] + function_name = last_message.recipient[len("browser.") :] action = None parsed_args = browser_tool.process_arguments(last_message) if function_name == "search": - action = WebSearchActionSearch( - type="search", - query=parsed_args["query"], - ) + action = WebSearchActionSearch(type="search", query=parsed_args["query"]) elif function_name == "open": - action = WebSearchActionOpenPage( - type="open_page", - url=parsed_args["url"] if "url" in parsed_args else None, - ) + action = WebSearchActionOpenPage(type="open_page", url=parsed_args.get("url")) elif function_name == "find": action = WebSearchActionFind( type="find", pattern=parsed_args["pattern"], - url=parsed_args["url"] if "url" in parsed_args else None, + url=parsed_args.get("url"), ) if action is not None: web_search_call_id = f"ws_{uuid.uuid4().hex}" self.browser_call_ids.append(web_search_call_id) - yield self._send_event(ResponseOutputItemAdded( - type="response.output_item.added", - output_index=current_output_index, - item=WebSearchCallItem( - type="web_search_call", - id=web_search_call_id, - action=action, - ), - )) + yield self._send_event( + ResponseOutputItemAdded( + type="response.output_item.added", + output_index=current_output_index, + item=WebSearchCallItem( + type="web_search_call", + id=web_search_call_id, + action=action, + ), + ) + ) yield self._send_event( ResponseWebSearchCallInProgress( 
type="response.web_search_call.in_progress", output_index=current_output_index, - id=web_search_call_id + id=web_search_call_id, ) ) @@ -676,10 +656,12 @@ async def run_tool(): new_tokens = encoding.render_conversation_for_completion( Conversation.from_messages(result), Role.ASSISTANT ) - + print(encoding.decode_utf8(new_tokens)) self.output_tokens.append(next_tok) - self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0]) + self.tokens.append( + encoding.encode("<|end|>", allowed_special="all")[0] + ) for token in new_tokens: self.parser.process(token) @@ -693,29 +675,29 @@ async def run_tool(): id=web_search_call_id, ) ) - yield self._send_event(ResponseOutputItemDone( - type="response.output_item.done", - output_index=current_output_index, - item=WebSearchCallItem( - type="web_search_call", - id=web_search_call_id, - action=action, - ), - )) + yield self._send_event( + ResponseOutputItemDone( + type="response.output_item.done", + output_index=current_output_index, + item=WebSearchCallItem( + type="web_search_call", + id=web_search_call_id, + action=action, + ), + ) + ) current_output_index += 1 self.new_request = True - continue - else: break else: raise ValueError("No messages to process") + if len(self.output_tokens) >= self.request_body.max_output_tokens: break - # Adding in the end if we know we are not done self.output_tokens.append(next_tok) if self.request is None or not await self.request.is_disconnected(): @@ -733,28 +715,20 @@ async def run_tool(): if self.store_callback and self.request_body.store: self.store_callback(self.response_id, self.request_body, response) yield self._send_event( - ResponseCompletedEvent( - type="response.completed", - response=response, - ) + ResponseCompletedEvent(type="response.completed", response=response) ) @app.post("/v1/responses", response_model=ResponseObject) async def generate(body: ResponsesRequest, request: Request): print("request received") - use_browser_tool = any( - getattr(tool, "type", None) == "browser_search" - for tool in (body.tools or []) - ) + tools_list = list(getattr(body, "tools", []) or []) + use_browser_tool = any(getattr(tool, "type", None) == "browser_search" for tool in tools_list) + browser_tool: Optional[SimpleBrowserTool] = None if use_browser_tool: - backend = ExaBackend( - source="web", - ) + backend = ExaBackend(source="web") browser_tool = SimpleBrowserTool(backend=backend) - else: - browser_tool = None if body.previous_response_id: prev = responses_store.get(body.previous_response_id) @@ -779,11 +753,10 @@ def _ensure_list(inp): body.instructions = prev_req.instructions body.input = merged_input - system_message_content = SystemContent.new().with_conversation_start_date( datetime.datetime.now().strftime("%Y-%m-%d") ) - + if body.reasoning is not None: try: @@ -793,116 +766,98 @@ def _ensure_list(inp): raise HTTPException(status_code=422, detail=str(e)) system_message_content = system_message_content.with_reasoning_effort(reasoning_effort) - if use_browser_tool: + if use_browser_tool and browser_tool is not None: system_message_content = system_message_content.with_tools(browser_tool.tool_config) - system_message = Message.from_role_and_content( - Role.SYSTEM, system_message_content - ) - messages = [system_message] + system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content) + messages: List[Message] = [system_message] - if body.instructions or body.tools: - developer_message_content = DeveloperContent.new().with_instructions( - body.instructions - ) + # Developer 
instructions and function tools + if body.instructions or tools_list: + developer_message_content = DeveloperContent.new().with_instructions(body.instructions) - tools = [] - for tool in body.tools: - if tool.type == "function": - tools.append( - ToolDescription.new( - tool.name, - tool.description, - tool.parameters, - ) + # Resolve conflict: build function tools once + function_tools: List[ToolDescription] = [] + for tool in tools_list: + if getattr(tool, "type", None) == "function": + function_tools.append( + ToolDescription.new(tool.name, tool.description, tool.parameters) ) - if tools: - developer_message_content = developer_message_content.with_function_tools( - tools - ) - - developer_message = Message.from_role_and_content( - Role.DEVELOPER, developer_message_content - ) + if function_tools: + developer_message_content = developer_message_content.with_function_tools(function_tools) + developer_message = Message.from_role_and_content(Role.DEVELOPER, developer_message_content) messages.append(developer_message) + # User and prior content if isinstance(body.input, str): user_message = Message.from_role_and_content(Role.USER, body.input) messages.append(user_message) else: is_last_message_function_call_output = ( - len(body.input) > 0 and body.input[-1].type == "function_call_output" + len(body.input) > 0 and getattr(body.input[-1], "type", None) == "function_call_output" ) - function_call_map = {} - # Find the index of the last assistant message + function_call_map: Dict[str, FunctionCallItem] = {} last_assistant_idx = -1 for idx, item in enumerate(body.input): - if item.type == "message" and item.role == Role.ASSISTANT: + if getattr(item, "type", None) == "message" and getattr(item, "role", None) == Role.ASSISTANT: last_assistant_idx = idx for idx, item in enumerate(body.input): - if item.type == "message": - # TODO: add system prompt handling - if isinstance(item.content, str): - messages.append( - Message.from_role_and_content(item.role, item.content) - ) + itype = getattr(item, "type", None) + if itype == "message": + content = getattr(item, "content", None) + role = getattr(item, "role", None) + if isinstance(content, str): + messages.append(Message.from_role_and_content(role, content)) else: - for content_item in item.content: - messages.append( - Message.from_role_and_content(item.role, content_item.text) - ) - # add final channel to the last assistant message if it's from the assistant - if item.role == Role.ASSISTANT: + for content_item in content: + messages.append(Message.from_role_and_content(role, content_item.text)) + if role == Role.ASSISTANT: messages[-1] = messages[-1].with_channel("final") - elif item.type == "reasoning": - # Only include reasoning if it is after the last assistant message and we are handling a function call at the moment - if ( - idx > last_assistant_idx - and is_last_message_function_call_output - ): + + elif itype == "reasoning": + if idx > last_assistant_idx and is_last_message_function_call_output: for content_item in item.content: messages.append( - Message.from_role_and_content( - Role.ASSISTANT, content_item.text - ).with_channel("analysis") + Message.from_role_and_content(Role.ASSISTANT, content_item.text).with_channel("analysis") ) - elif item.type == "function_call": - function_call_map[item.call_id] = item + + elif itype == "function_call": + function_call_map[item.call_id] = item # type: ignore[index] messages.append( - Message.from_role_and_content(Role.ASSISTANT, item.arguments) - .with_recipient(f"functions.{item.name}") + 
Message.from_role_and_content(Role.ASSISTANT, item.arguments) # type: ignore[arg-type] + .with_recipient(f"functions.{item.name}") # type: ignore[attr-defined] .with_channel("commentary") ) - elif item.type == "function_call_output": - function_call = function_call_map.get(item.call_id, None) + + elif itype == "function_call_output": + function_call = function_call_map.get(item.call_id) # type: ignore[index] if not function_call: raise ValueError(f"Function call {item.call_id} not found") messages.append( Message.from_author_and_content( Author.new(Role.TOOL, f"functions.{function_call.name}"), - item.output, - ).with_recipient("assistant").with_channel("commentary") + item.output, # type: ignore[arg-type] + ) + .with_recipient("assistant") + .with_channel("commentary") ) conversation = Conversation.from_messages(messages) - - initial_tokens = encoding.render_conversation_for_completion( - conversation, Role.ASSISTANT - ) + initial_tokens = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT) print(encoding.decode_utf8(initial_tokens)) response_id = f"resp_{uuid.uuid4().hex}" - def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject): + def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject) -> None: responses_store[rid] = (req, resp) event_stream = StreamResponsesEvents( initial_tokens, body, - as_sse=body.stream, + as_sse=bool(body.stream), request=request, response_id=response_id, store_callback=store_callback, @@ -915,7 +870,7 @@ def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject): last_event = None async for event in event_stream.run(): last_event = event - + # last_event is a ResponseCompletedEvent return last_event.response return app diff --git a/gpt_oss/responses_api/events.py b/gpt_oss/responses_api/events.py index 7adecc6..be63cf9 100644 --- a/gpt_oss/responses_api/events.py +++ b/gpt_oss/responses_api/events.py @@ -3,16 +3,9 @@ from pydantic import BaseModel -from .types import ( - FunctionCallItem, - Item, - ReasoningItem, - ResponseObject, - TextContentItem, - ReasoningTextContentItem, - WebSearchCallItem, - UrlCitation, -) +from .types import (FunctionCallItem, Item, ReasoningItem, + ReasoningTextContentItem, ResponseObject, TextContentItem, + UrlCitation, WebSearchCallItem) class ResponseEvent(BaseModel): @@ -105,25 +98,37 @@ class ResponseContentPartDone(ResponseEvent): content_index: int = 0 part: Union[TextContentItem, ReasoningTextContentItem] + class ResponseOutputTextAnnotationAdded(ResponseEvent): - type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added" + type: Literal["response.output_text.annotation.added"] = ( + "response.output_text.annotation.added" + ) item_id: str = "item_1234" output_index: int = 0 content_index: int = 0 annotation_index: int = 0 annotation: UrlCitation + class ResponseWebSearchCallInProgress(ResponseEvent): - type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress" + type: Literal["response.web_search_call.in_progress"] = ( + "response.web_search_call.in_progress" + ) output_index: int = 0 item_id: str = "item_1234" + class ResponseWebSearchCallSearching(ResponseEvent): - type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching" + type: Literal["response.web_search_call.searching"] = ( + "response.web_search_call.searching" + ) output_index: int = 0 item_id: str = "item_1234" + class ResponseWebSearchCallCompleted(ResponseEvent): - type: 
Literal["response.web_search_call.completed"] = "response.web_search_call.completed" + type: Literal["response.web_search_call.completed"] = ( + "response.web_search_call.completed" + ) output_index: int = 0 - item_id: str = "item_1234" \ No newline at end of file + item_id: str = "item_1234" diff --git a/gpt_oss/responses_api/inference/ollama.py b/gpt_oss/responses_api/inference/ollama.py index 35eb1b2..43a87c7 100644 --- a/gpt_oss/responses_api/inference/ollama.py +++ b/gpt_oss/responses_api/inference/ollama.py @@ -8,17 +8,17 @@ import threading import time from typing import Callable, Optional -import requests -from openai_harmony import load_harmony_encoding, HarmonyEncodingName +import requests +from openai_harmony import HarmonyEncodingName, load_harmony_encoding EOS_TOKEN = 200002 # only used on hard timeout # Tunables -POLL_INTERVAL_S = 0.01 # 10ms between buffer checks -CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call -NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS -FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS +POLL_INTERVAL_S = 0.01 # 10ms between buffer checks +CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call +NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS +FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS # Shared state _token_buffer: list[int] = [] @@ -26,9 +26,10 @@ _stream_thread: Optional[threading.Thread] = None _stream_done = threading.Event() _stream_error: Optional[Exception] = None -_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens +_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens _previous_request_tokens: list[int] = [] + def lcp(cache: list[int], inp: list[int]) -> list[int]: i = 0 max_len = min(len(cache), len(inp)) @@ -36,13 +37,16 @@ def lcp(cache: list[int], inp: list[int]) -> list[int]: i += 1 return cache[:i] + def _now(): return time.monotonic() + def _touch_progress(): global _last_progress_ts _last_progress_ts = _now() + def _reset_stream_state(): global _token_buffer, _stream_thread, _stream_error with _buffer_lock: @@ -52,14 +56,15 @@ def _reset_stream_state(): _stream_error = None _touch_progress() + def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]: encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) model_name = checkpoint def _start_stream(token_ids: list[int], temperature: float): prompt_text = encoding.decode(token_ids) + def run(): - nonlocal prompt_text, temperature global _stream_error global _previous_request_tokens @@ -187,6 +192,8 @@ def infer_next_token( # If we reach here, we still haven't got a token—ask the caller to call again soon. # Return a harmless token that the server will replace/ignore if your interface supports it. # If your interface does NOT allow a sentinel, keep the short-blocking behavior above. 
- return EOS_TOKEN if False else 0 # replace `0` with a PAD/NOOP token your server ignores + return ( + EOS_TOKEN if False else 0 + ) # replace `0` with a PAD/NOOP token your server ignores return infer_next_token diff --git a/gpt_oss/responses_api/inference/transformers.py b/gpt_oss/responses_api/inference/transformers.py index c743707..23f6691 100644 --- a/gpt_oss/responses_api/inference/transformers.py +++ b/gpt_oss/responses_api/inference/transformers.py @@ -6,14 +6,14 @@ import os from typing import Callable, List +import torch # Transformers imports from transformers import AutoModelForCausalLM, PreTrainedModel -import torch - DEFAULT_TEMPERATURE = 0.0 TP = os.environ.get("TP", 2) + def load_model(checkpoint: str): """ Serve the model directly with the Auto API. @@ -41,10 +41,15 @@ def get_infer_next_token(model: PreTrainedModel): def infer_next_token( tokens: List[int], temperature: float = DEFAULT_TEMPERATURE, - new_request: bool = False, # kept for interface compatibility; unused here + new_request: bool = False, # kept for interface compatibility; unused here ) -> int: tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device) - output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature) + output = model.generate( + tokens, + max_new_tokens=1, + do_sample=temperature != 0, + temperature=temperature, + ) return output[0, -1].tolist() return infer_next_token diff --git a/gpt_oss/responses_api/inference/vllm.py b/gpt_oss/responses_api/inference/vllm.py index 9c07c55..b232b79 100644 --- a/gpt_oss/responses_api/inference/vllm.py +++ b/gpt_oss/responses_api/inference/vllm.py @@ -1,6 +1,6 @@ """ -NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers -one token at a time to mimic the behavior of the Triton implementation. +NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers +one token at a time to mimic the behavior of the Triton implementation. """ import os @@ -13,6 +13,7 @@ DEFAULT_TEMPERATURE = 0.0 TP = os.environ.get("TP", 2) + def load_model(checkpoint: str): """ Create the vLLM engine. We enable prefix caching so repeated prefixes @@ -21,9 +22,9 @@ def load_model(checkpoint: str): llm = LLM( model=checkpoint, - tensor_parallel_size=TP, # set >1 if you want TP across GPUs - enable_prefix_caching=True, # reuse KV for shared prefixes - disable_log_stats=True, # uncomment to quiet logs + tensor_parallel_size=TP, # set >1 if you want TP across GPUs + enable_prefix_caching=True, # reuse KV for shared prefixes + disable_log_stats=True, # uncomment to quiet logs ) return llm @@ -52,8 +53,8 @@ def infer_next_token( sampling = SamplingParams( temperature=float(temperature), - max_tokens=1, # we only want the next token - n=1, # single continuation + max_tokens=1, # we only want the next token + n=1, # single continuation # You can expose/enable more controls here (top_p, top_k, etc.) 
) diff --git a/gpt_oss/responses_api/serve.py b/gpt_oss/responses_api/serve.py index 35fc3f4..fc5c19e 100644 --- a/gpt_oss/responses_api/serve.py +++ b/gpt_oss/responses_api/serve.py @@ -3,10 +3,7 @@ import argparse import uvicorn -from openai_harmony import ( - HarmonyEncodingName, - load_harmony_encoding, -) +from openai_harmony import HarmonyEncodingName, load_harmony_encoding from .api_server import create_api_server diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py index 1d908e3..35d17a6 100644 --- a/gpt_oss/responses_api/types.py +++ b/gpt_oss/responses_api/types.py @@ -8,6 +8,7 @@ REASONING_EFFORT = ReasoningEffort.LOW DEFAULT_MAX_OUTPUT_TOKENS = 10_000 + class UrlCitation(BaseModel): type: Literal["url_citation"] end_index: int @@ -15,6 +16,7 @@ class UrlCitation(BaseModel): url: str title: str + class TextContentItem(BaseModel): type: Union[Literal["text"], Literal["input_text"], Literal["output_text"]] text: str @@ -61,25 +63,30 @@ class FunctionCallOutputItem(BaseModel): call_id: str = "call_1234" output: str + class WebSearchActionSearch(BaseModel): type: Literal["search"] query: Optional[str] = None + class WebSearchActionOpenPage(BaseModel): type: Literal["open_page"] url: Optional[str] = None + class WebSearchActionFind(BaseModel): type: Literal["find"] pattern: Optional[str] = None url: Optional[str] = None + class WebSearchCallItem(BaseModel): type: Literal["web_search_call"] id: str = "ws_1234" status: Literal["in_progress", "completed", "incomplete"] = "completed" action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind] + class Error(BaseModel): code: str message: str @@ -115,7 +122,16 @@ class ResponsesRequest(BaseModel): instructions: Optional[str] = None max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS input: Union[ - str, list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]] + str, + list[ + Union[ + Item, + ReasoningItem, + FunctionCallItem, + FunctionCallOutputItem, + WebSearchCallItem, + ] + ], ] model: Optional[str] = MODEL_IDENTIFIER stream: Optional[bool] = False @@ -131,7 +147,15 @@ class ResponsesRequest(BaseModel): class ResponseObject(BaseModel): - output: list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]] + output: list[ + Union[ + Item, + ReasoningItem, + FunctionCallItem, + FunctionCallOutputItem, + WebSearchCallItem, + ] + ] created_at: int usage: Optional[Usage] = None status: Literal["completed", "failed", "incomplete", "in_progress"] = "in_progress" diff --git a/gpt_oss/tokenizer.py b/gpt_oss/tokenizer.py index 866077f..c12ac29 100644 --- a/gpt_oss/tokenizer.py +++ b/gpt_oss/tokenizer.py @@ -1,5 +1,6 @@ import tiktoken + def get_tokenizer(): o200k_base = tiktoken.get_encoding("o200k_base") tokenizer = tiktoken.Encoding( @@ -23,8 +24,7 @@ def get_tokenizer(): "<|reserved_200010|>": 200010, "<|reserved_200011|>": 200011, "<|call|>": 200012, - } | { - f"<|reserved_{i}|>": i for i in range(200013, 201088) - }, + } + | {f"<|reserved_{i}|>": i for i in range(200013, 201088)}, ) return tokenizer diff --git a/gpt_oss/tools/apply_patch.py b/gpt_oss/tools/apply_patch.py index a1ecb11..007199e 100644 --- a/gpt_oss/tools/apply_patch.py +++ b/gpt_oss/tools/apply_patch.py @@ -12,14 +12,7 @@ import pathlib from dataclasses import dataclass, field from enum import Enum -from typing import ( - Callable, - Dict, - List, - Optional, - Tuple, - Union, -) +from typing import Callable, Dict, List, Optional, Tuple, Union # 
--------------------------------------------------------------------------- # @@ -493,11 +486,10 @@ def remove_file(path: str) -> None: pathlib.Path(path).unlink(missing_ok=True) - def apply_patch( text: str, open_fn: Callable[[str], str] = open_file, - write_fn: Callable[[str, str], None] = write_file, + write_fn: Callable[[str, str], None] = write_file, remove_fn: Callable[[str], None] = remove_file, ) -> str: if not text.startswith("*** Begin Patch"): diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py index 7067c1e..9c2137c 100644 --- a/gpt_oss/tools/python_docker/docker_tool.py +++ b/gpt_oss/tools/python_docker/docker_tool.py @@ -1,22 +1,15 @@ # Run this before running the tool: # $ docker image pull python:3.11 +import io +import tarfile from typing import Any, AsyncIterator import docker -from openai_harmony import ( - Author, - Content, - Message, - Role, - TextContent, - ToolNamespaceConfig, -) -import io -import tarfile +from openai_harmony import (Author, Content, Message, Role, TextContent, + ToolNamespaceConfig) from ..tool import Tool - _docker_client = None @@ -84,9 +77,7 @@ def instruction(self) -> str: @property def tool_config(self) -> ToolNamespaceConfig: return ToolNamespaceConfig( - name=self.get_tool_name(), - description=self.instruction, - tools=[] + name=self.get_tool_name(), description=self.instruction, tools=[] ) def _make_response( @@ -111,7 +102,7 @@ def make_response( message = Message( author=author, content=[content], - ).with_recipient('assistant') + ).with_recipient("assistant") if channel: message = message.with_channel(channel) diff --git a/gpt_oss/tools/simple_browser/__init__.py b/gpt_oss/tools/simple_browser/__init__.py index 9043cb1..b4f2628 100644 --- a/gpt_oss/tools/simple_browser/__init__.py +++ b/gpt_oss/tools/simple_browser/__init__.py @@ -1,5 +1,5 @@ -from .simple_browser_tool import SimpleBrowserTool from .backend import ExaBackend +from .simple_browser_tool import SimpleBrowserTool __all__ = [ "SimpleBrowserTool", diff --git a/gpt_oss/tools/simple_browser/backend.py b/gpt_oss/tools/simple_browser/backend.py index 03bdf56..17ada2e 100644 --- a/gpt_oss/tools/simple_browser/backend.py +++ b/gpt_oss/tools/simple_browser/backend.py @@ -11,22 +11,12 @@ import chz from aiohttp import ClientSession, ClientTimeout -from tenacity import ( - after_log, - before_sleep_log, - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from .page_contents import ( - Extract, - FetchResult, - PageContents, - get_domain, - process_html, -) +from tenacity import (after_log, before_sleep_log, retry, + retry_if_exception_type, stop_after_attempt, + wait_exponential) + +from .page_contents import (Extract, FetchResult, PageContents, get_domain, + process_html) logger = logging.getLogger(__name__) @@ -108,11 +98,11 @@ def _get_api_key(self) -> str: async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict: headers = {"x-api-key": self._get_api_key()} - async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp: + async with session.post( + f"{self.BASE_URL}{endpoint}", json=payload, headers=headers + ) as resp: if resp.status != 200: - raise BackendError( - f"Exa API error {resp.status}: {await resp.text()}" - ) + raise BackendError(f"Exa API error {resp.status}: {await resp.text()}") return await resp.json() async def search( @@ -121,7 +111,11 @@ async def search( data = await self._post( session, "/search", - {"query": query, 
"numResults": topn, "contents": {"text": True, "summary": True}}, + { + "query": query, + "numResults": topn, + "contents": {"text": True, "summary": True}, + }, ) # make a simple HTML page to work with browser format titles_and_urls = [ @@ -152,7 +146,7 @@ async def fetch(self, url: str, session: ClientSession) -> PageContents: data = await self._post( session, "/contents", - {"urls": [url], "text": { "includeHtmlTags": True }}, + {"urls": [url], "text": {"includeHtmlTags": True}}, ) results = data.get("results", []) if not results: diff --git a/gpt_oss/tools/simple_browser/page_contents.py b/gpt_oss/tools/simple_browser/page_contents.py index 6fffd3f..ab81454 100644 --- a/gpt_oss/tools/simple_browser/page_contents.py +++ b/gpt_oss/tools/simple_browser/page_contents.py @@ -16,7 +16,6 @@ import lxml.etree import lxml.html import pydantic - import tiktoken logger = logging.getLogger(__name__) diff --git a/gpt_oss/tools/simple_browser/simple_browser_tool.py b/gpt_oss/tools/simple_browser/simple_browser_tool.py index 913ee0b..bf619d9 100644 --- a/gpt_oss/tools/simple_browser/simple_browser_tool.py +++ b/gpt_oss/tools/simple_browser/simple_browser_tool.py @@ -12,24 +12,12 @@ import structlog import tiktoken from aiohttp import ClientSession -from openai_harmony import ( - Author, - Content, - Message, - Role, - TextContent, - ToolNamespaceConfig -) +from openai_harmony import (Author, Content, Message, Role, TextContent, + ToolNamespaceConfig) from ..tool import Tool - # from functions import Function, from_python -from .backend import ( - VIEW_SOURCE_PREFIX, - Backend, - BackendError, - maybe_truncate, -) +from .backend import VIEW_SOURCE_PREFIX, Backend, BackendError, maybe_truncate from .page_contents import Extract, PageContents logger = structlog.stdlib.get_logger(component=__name__) @@ -44,7 +32,9 @@ ) LINK_PATTERN = re.compile(r"【\d+†(?P[^†】]+)(?:†[^†】]+)?】") -CITATION_OUTPUT_PATTERN = re.compile(r"【(?P\d+)†(?P[^†】]+)(?:†[^†】]+)?】") +CITATION_OUTPUT_PATTERN = re.compile( + r"【(?P\d+)†(?P[^†】]+)(?:†[^†】]+)?】" +) CallParams = ParamSpec("CallParams") @@ -352,12 +342,15 @@ def name(self) -> str: def tool_config(self) -> ToolNamespaceConfig: config = ToolNamespaceConfig.browser() config.name = self.name - config.description = """Tool for browsing. + config.description = ( + """Tool for browsing. The `cursor` appears in brackets before each browsing display: `[{cursor}]`. Cite information from the tool using the following format: `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`. Do not quote more than 10 words directly from the tool output. 
-sources=""" + self.backend.source +sources=""" + + self.backend.source + ) return config @property @@ -616,8 +609,9 @@ def make_error_message(error: str) -> Message: else: raise ValueError("should not be here") - - def normalize_citations(self, old_content: str, hide_partial_citations: bool = False) -> tuple[str, list[dict[str, Any]], bool]: + def normalize_citations( + self, old_content: str, hide_partial_citations: bool = False + ) -> tuple[str, list[dict[str, Any]], bool]: """ Returns a tuple of (new_message, annotations, has_partial_citations) - new_message: Message with citations replaced by ([domain](url)) @@ -625,7 +619,9 @@ def normalize_citations(self, old_content: str, hide_partial_citations: bool = F - has_partial_citations: whether the text includes an unfinished citation """ - has_partial_citations = PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None + has_partial_citations = ( + PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None + ) if hide_partial_citations and has_partial_citations: old_content = PARTIAL_FINAL_LINK_PATTERN.sub("", old_content) @@ -635,12 +631,14 @@ def normalize_citations(self, old_content: str, hide_partial_citations: bool = F content = match.group("content") start_idx = match.start() end_idx = match.end() - matches.append({ - "cursor": cursor, - "content": content, - "start": start_idx, - "end": end_idx - }) + matches.append( + { + "cursor": cursor, + "content": content, + "start": start_idx, + "end": end_idx, + } + ) # Build a mapping from cursor to url cursor_to_url = {} @@ -673,13 +671,15 @@ def extract_domain(url): # The start and end indices in the new content start_index = len(new_content) end_index = start_index + len(replacement) - annotations.append({ - "start_index": start_index, - "end_index": end_index, - "title": domain, - "url": url, - "type": "url_citation", - }) + annotations.append( + { + "start_index": start_index, + "end_index": end_index, + "title": domain, + "url": url, + "type": "url_citation", + } + ) new_content += replacement else: # Keep the original citation format if cursor is missing @@ -693,4 +693,3 @@ def extract_domain(url): new_content += old_content[last_idx:] return new_content, annotations, has_partial_citations - diff --git a/gpt_oss/tools/tool.py b/gpt_oss/tools/tool.py index 210a3d1..374feb9 100644 --- a/gpt_oss/tools/tool.py +++ b/gpt_oss/tools/tool.py @@ -1,13 +1,8 @@ from abc import ABC, abstractmethod -from uuid import UUID, uuid4 from typing import AsyncIterator +from uuid import UUID, uuid4 -from openai_harmony import ( - Author, - Role, - Message, - TextContent, -) +from openai_harmony import Author, Message, Role, TextContent def _maybe_update_inplace_and_validate_channel( @@ -63,13 +58,17 @@ async def process(self, message: Message) -> AsyncIterator[Message]: """ async for m in self._process(message): if self.output_channel_should_match_input_channel: - _maybe_update_inplace_and_validate_channel(input_message=message, tool_message=m) + _maybe_update_inplace_and_validate_channel( + input_message=message, tool_message=m + ) yield m @abstractmethod async def _process(self, message: Message) -> AsyncIterator[Message]: """Override this method to provide the implementation of the tool.""" - if False: # This is to convince the type checker that this is an async generator. + if ( + False + ): # This is to convince the type checker that this is an async generator. yield # type: ignore[unreachable] _ = message # Stifle "unused argument" warning. 
raise NotImplementedError @@ -94,7 +93,6 @@ def error_message( return Message( id=id if id else uuid4(), author=Author(role=Role.TOOL, name=self.name), - content=TextContent(text=error_message), # TODO: Use SystemError instead + content=TextContent(text=f"Error: {error_message}"), channel=channel, ).with_recipient("assistant") - diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index 9180d49..3c473e0 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -420,8 +420,10 @@ def from_checkpoint( if "mlp1" in name: # both weight and bias loaded_tensor = loaded_tensor[ :, - my_rank * 2 - * per_rank_intermediate_size : (my_rank + 1) * 2 + my_rank + * 2 + * per_rank_intermediate_size : (my_rank + 1) + * 2 * per_rank_intermediate_size, ..., ] @@ -448,16 +450,20 @@ def __init__(self, checkpoint: str, device: torch.device): self.model = Transformer.from_checkpoint(checkpoint, device=self.device) @torch.inference_mode() - def generate(self, - prompt_tokens: list[int], - stop_tokens: list[int], - temperature: float = 1.0, - max_tokens: int = 0, - return_logprobs: bool = False): + def generate( + self, + prompt_tokens: list[int], + stop_tokens: list[int], + temperature: float = 1.0, + max_tokens: int = 0, + return_logprobs: bool = False, + ): tokens = list(prompt_tokens) num_generated_tokens = 0 while max_tokens == 0 or num_generated_tokens < max_tokens: - logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1] + logits = self.model( + torch.as_tensor(tokens, dtype=torch.int32, device=self.device) + )[-1] if temperature == 0.0: predicted_token = torch.argmax(logits, dim=-1).item() else: diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py index ce87a85..10ce783 100644 --- a/gpt_oss/torch/utils.py +++ b/gpt_oss/torch/utils.py @@ -1,4 +1,5 @@ import os + import torch import torch.distributed as dist @@ -6,10 +7,11 @@ def suppress_output(rank): """Suppress printing on the current device. Force printing with `force=True`.""" import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if force: builtin_print("rank #%d:" % rank, *args, **kwargs) elif rank == 0: diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py index aa5df58..1fd6867 100644 --- a/gpt_oss/torch/weights.py +++ b/gpt_oss/torch/weights.py @@ -4,25 +4,47 @@ import torch from safetensors import safe_open - # Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes BYTES_PER_BLOCK = 16 FP4_VALUES = [ - +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] # Map the names assumed in this implementation to the checkpoint names. 
-PARAM_NAME_MAP = { - f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.mlp1_bias" for n in range(36) -} | { - f"block.{n}.mlp.mlp1_weight": (f"block.{n}.mlp.mlp1_weight.blocks", f"block.{n}.mlp.mlp1_weight.scales") for n in range(36) -} | { - f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.mlp2_bias" for n in range(36) -} | { - f"block.{n}.mlp.mlp2_weight": (f"block.{n}.mlp.mlp2_weight.blocks", f"block.{n}.mlp.mlp2_weight.scales") for n in range(36) -} +PARAM_NAME_MAP = ( + {f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.mlp1_bias" for n in range(36)} + | { + f"block.{n}.mlp.mlp1_weight": ( + f"block.{n}.mlp.mlp1_weight.blocks", + f"block.{n}.mlp.mlp1_weight.scales", + ) + for n in range(36) + } + | {f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.mlp2_bias" for n in range(36)} + | { + f"block.{n}.mlp.mlp2_weight": ( + f"block.{n}.mlp.mlp2_weight.blocks", + f"block.{n}.mlp.mlp2_weight.scales", + ) + for n in range(36) + } +) class Checkpoint: @@ -53,13 +75,17 @@ def get(self, name: str) -> torch.Tensor: match PARAM_NAME_MAP.get(name, name): case (blocks_name, scales_name): # MoE weights: are in block-based MXFP4 format - return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16) + return self._get_mxfp4_tensor( + blocks_name, scales_name, dtype=torch.bfloat16 + ) case tensor_name: # MoE biases and other weights return self._get_tensor(tensor_name) def _get_tensor(self, name: str) -> str: - assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint." + assert ( + name in self.tensor_name_to_file + ), f"Tensor {name} not found in checkpoint." with safe_open( self.tensor_name_to_file[name], framework="pt", device=self.device_str ) as f: @@ -73,24 +99,24 @@ def _get_mxfp4_tensor( dtype: torch.dtype = torch.bfloat16, rows_per_chunk: int = 16384 * 512, ) -> torch.Tensor: - assert blocks_name in self.tensor_name_to_file, ( - f"Blocks tensor {blocks_name} not found in checkpoint." - ) - assert scales_name in self.tensor_name_to_file, ( - f"Scales tensor {scales_name} not found in checkpoint." - ) + assert ( + blocks_name in self.tensor_name_to_file + ), f"Blocks tensor {blocks_name} not found in checkpoint." + assert ( + scales_name in self.tensor_name_to_file + ), f"Scales tensor {scales_name} not found in checkpoint." 
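Editorial aside, not part of the patch: the MXFP4 handling reformatted in these weights.py hunks reduces to a small decode step, sketched below under the same assumptions the file states (two FP4 nibbles per byte, `FP4_VALUES` as the lookup table, and a per-block power-of-two exponent stored with a bias of 127). The helper name `dequant_mxfp4` and the toy shapes are illustrative only.

```python
import torch

FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_mxfp4(blocks: torch.Tensor, scales: torch.Tensor,
                  dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # blocks: uint8 [..., G, 16]; each byte packs two FP4 nibbles (low, high).
    # scales: uint8 [..., G]; shared power-of-two exponent, biased by 127.
    lut = torch.tensor(FP4_VALUES, dtype=dtype)
    lo, hi = blocks & 0x0F, blocks >> 4
    idx = torch.stack((lo, hi), dim=-1).reshape(*blocks.shape[:-1], -1)
    exp = scales.to(torch.int32) - 127
    return torch.ldexp(lut[idx.long()], exp.unsqueeze(-1))

# One block of 16 bytes, all 0x21: nibbles decode to 0.5, 1.0 alternating,
# and a stored scale of 128 means an exponent of +1, i.e. multiply by 2.
blocks = torch.full((1, 16), 0x21, dtype=torch.uint8)
scales = torch.tensor([128], dtype=torch.uint8)
print(dequant_mxfp4(blocks, scales)[0, :4])  # tensor([1., 2., 1., 2.], dtype=torch.bfloat16)
```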
blocks = self._get_tensor(blocks_name) scales = self._get_tensor(scales_name).to(torch.int32) - 127 - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape=} does not match {scales.shape=}" - ) + assert ( + blocks.shape[:-1] == scales.shape + ), f"{blocks.shape=} does not match {scales.shape=}" lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G + rows_total = math.prod(prefix_shape) * G blocks = blocks.reshape(rows_total, B) scales = scales.reshape(rows_total, 1) @@ -116,7 +142,9 @@ def _get_mxfp4_tensor( return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) - def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16): + def _get_mxfp4_tensor_copy( + self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16 + ): "short version that uses a lot of memory" loaded_blocks = self._get_tensor(blocks_name) @@ -124,7 +152,9 @@ def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torc loaded_blocks_lo = loaded_blocks & 0x0F loaded_blocks_hi = loaded_blocks >> 4 loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1) - loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2) + loaded_blocks = loaded_blocks.view( + *loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2 + ) loaded_scales = self._get_tensor(scales_name) # Upcast to int32 and subtract bias @@ -132,6 +162,8 @@ def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torc # Convert MXFP4 numbers into target dtype fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str) - loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1)) + loaded_tensor = torch.ldexp( + fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1) + ) loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1) return loaded_tensor diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py index e9222f0..8f7b845 100644 --- a/gpt_oss/triton/attention.py +++ b/gpt_oss/triton/attention.py @@ -8,7 +8,6 @@ import pytest import torch - import triton import triton.language as tl @@ -111,7 +110,10 @@ def _attn_fwd( q = tl.load(Q_block_ptr) if BANDWIDTH: - lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), (start_q + start_m + 1) * BLOCK_M + lo, hi = ( + tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), + (start_q + start_m + 1) * BLOCK_M, + ) else: lo, hi = start_q, (start_q + start_m + 1) * BLOCK_M @@ -125,7 +127,9 @@ def _attn_fwd( mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] if BANDWIDTH: - too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) + too_old = (start_n + offs_n[None, :]) < ( + start_q + offs_m[:, None] - BANDWIDTH + 1 + ) mask = mask | too_old k = tl.load(K_block_ptr) @@ -186,7 +190,9 @@ def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q): v = torch.nn.functional.pad(v, (0, 0, 0, n_pad_size)) o = torch.empty_like(q) - M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32) + M = torch.empty( + (bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32 + ) grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1) _attn_fwd[grid]( q, @@ -244,7 +250,9 @@ def attention_ref( sliding_window: int | None = None, start_q: torch.LongTensor = 0, ): - batch_size, num_queries, num_key_value_heads, 
num_key_value_groups, head_dim = query.shape + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = ( + query.shape + ) batch_size, num_keys, num_key_value_heads, head_dim = key.shape sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() @@ -272,7 +280,9 @@ def attention_ref( output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim).bfloat16() + output = output.reshape( + batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim + ).bfloat16() return output @@ -285,13 +295,37 @@ def attention_ref( @pytest.mark.parametrize("sm_scale", [0.125]) @pytest.mark.parametrize("sliding_window", [None, 128]) @pytest.mark.parametrize("start_q", [0, 5]) -def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q): +def test_eq( + batch_size, + num_queries, + num_keys, + num_key_value_heads, + num_key_value_groups, + head_dim, + sm_scale, + sliding_window, + start_q, +): if num_queries > num_keys: pytest.skip("too many queries") - q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda() - k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda() - v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda() + q = ( + torch.randn( + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim + ) + .bfloat16() + .cuda() + ) + k = ( + torch.randn(batch_size, num_keys, num_key_value_heads, head_dim) + .bfloat16() + .cuda() + ) + v = ( + torch.randn(batch_size, num_keys, num_key_value_heads, head_dim) + .bfloat16() + .cuda() + ) sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda() start_q = torch.tensor([start_q], dtype=torch.int32).cuda() @@ -299,4 +333,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q) o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q) - torch.testing.assert_close(o1, o2) \ No newline at end of file + torch.testing.assert_close(o1, o2) diff --git a/gpt_oss/triton/model.py b/gpt_oss/triton/model.py index 2e14478..008fbb2 100644 --- a/gpt_oss/triton/model.py +++ b/gpt_oss/triton/model.py @@ -8,7 +8,7 @@ from gpt_oss.torch.model import ModelConfig, RMSNorm from gpt_oss.torch.weights import Checkpoint from gpt_oss.triton.attention import attention, attention_ref -from gpt_oss.triton.moe import quantize_mx4, moe +from gpt_oss.triton.moe import moe, quantize_mx4 class RotaryEmbedding(torch.nn.Module): @@ -78,7 +78,9 @@ def _compute_concentration_and_inv_freq(self) -> torch.Tensor: def _compute_cos_sin(self, start: int, num_tokens: int): concentration, inv_freq = self._compute_concentration_and_inv_freq() - t = torch.arange(start, start + num_tokens, dtype=torch.float32, device=self.device) + t = torch.arange( + start, start + num_tokens, dtype=torch.float32, device=self.device + ) freqs = torch.einsum("i,j->ij", t, inv_freq) cos = freqs.cos() * concentration sin = freqs.sin() * concentration @@ -119,9 +121,20 @@ def forward( class Cache: - def __init__(self, batch_size, n_ctx, n_kv_heads, d_head=64, device: torch.device | None = None): - self.k = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device) - self.v = 
torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device) + def __init__( + self, + batch_size, + n_ctx, + n_kv_heads, + d_head=64, + device: torch.device | None = None, + ): + self.k = torch.zeros( + (batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device + ) + self.v = torch.zeros( + (batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device + ) self.offset = torch.zeros((1,), dtype=torch.long, device=device) def reset(self): @@ -147,7 +160,9 @@ def truncate(self, n_ctx): def extend(self, k, v): batch_size, n_ctx, *_rest = k.shape assert batch_size == self.k.shape[0] - indices = torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset + indices = ( + torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset + ) self.k.index_copy_(1, indices, k) self.v.index_copy_(1, indices, v) self.offset.add_(n_ctx) @@ -206,7 +221,7 @@ def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor: qkv_parts = ( self.num_attention_heads * self.head_dim, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim + self.num_key_value_heads * self.head_dim, ) q, k, v = torch.split(qkv, qkv_parts, dim=-1) q, k, v = q.contiguous(), k.contiguous(), v.contiguous() @@ -283,22 +298,24 @@ def __init__( self.experts_per_token = config.experts_per_token self.swiglu_limit = config.swiglu_limit self.norm = RMSNorm(config.hidden_size, device=device) - self.gate = torch.nn.ParameterDict({ - "weight": torch.nn.Parameter( - torch.empty( - (config.hidden_size, config.num_experts), - device=device, - dtype=torch.bfloat16, - ) - ), - "bias": torch.nn.Parameter( - torch.empty( - (config.num_experts,), - device=device, - dtype=torch.bfloat16, - ) - ), - }) + self.gate = torch.nn.ParameterDict( + { + "weight": torch.nn.Parameter( + torch.empty( + (config.hidden_size, config.num_experts), + device=device, + dtype=torch.bfloat16, + ) + ), + "bias": torch.nn.Parameter( + torch.empty( + (config.num_experts,), + device=device, + dtype=torch.bfloat16, + ) + ), + } + ) self.mlp1_weight_tensor, self.mlp1_weight_mx = quantize_mx4( torch.empty( ( @@ -310,7 +327,9 @@ def __init__( dtype=torch.bfloat16, ), ) - self.mlp1_weight = torch.nn.Parameter(self.mlp1_weight_tensor.storage.data, requires_grad=False) + self.mlp1_weight = torch.nn.Parameter( + self.mlp1_weight_tensor.storage.data, requires_grad=False + ) self.mlp1_bias = torch.nn.Parameter( torch.empty( (config.num_experts, config.intermediate_size * 2), @@ -329,7 +348,9 @@ def __init__( dtype=torch.bfloat16, ), ) - self.mlp2_weight = torch.nn.Parameter(self.mlp2_weight_tensor.storage.data, requires_grad=False) + self.mlp2_weight = torch.nn.Parameter( + self.mlp2_weight_tensor.storage.data, requires_grad=False + ) self.mlp2_bias = torch.nn.Parameter( torch.empty( (config.num_experts, config.hidden_size), @@ -347,8 +368,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: t = moe( t, self.gate["weight"], - self.mlp1_weight_tensor, self.mlp1_weight_mx, - self.mlp2_weight_tensor, self.mlp2_weight_mx, + self.mlp1_weight_tensor, + self.mlp1_weight_mx, + self.mlp2_weight_tensor, + self.mlp2_weight_mx, self.gate["bias"].float(), self.mlp1_bias.float(), self.mlp2_bias.float(), @@ -405,8 +428,10 @@ def __init__( dtype=torch.bfloat16, ) - def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.Tensor: - caches=caches or [None] * len(self.block) + def forward( + self, x: torch.Tensor, caches: list[Cache] | None = None + ) -> 
torch.Tensor: + caches = caches or [None] * len(self.block) with record_function("embedding"): x = self.embedding(x) for block, cache in zip(self.block, caches): @@ -420,7 +445,9 @@ def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.T @staticmethod def from_checkpoint( - path: str, config: ModelConfig | None = None, device: str | torch.device = "cuda", + path: str, + config: ModelConfig | None = None, + device: str | torch.device = "cuda", ) -> "Transformer": if not isinstance(device, torch.device): device = torch.device(device) @@ -472,7 +499,10 @@ class TokenGenerator: def __init__(self, checkpoint: str, context: int, device: torch.device): self.device = device self.model = Transformer.from_checkpoint(checkpoint, device=self.device) - self.caches = [Cache(1, context, self.model.config.num_key_value_heads, device=self.device) for _ in range(len(self.model.block))] + self.caches = [ + Cache(1, context, self.model.config.num_key_value_heads, device=self.device) + for _ in range(len(self.model.block)) + ] self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device) # warmup self.model(self.input_token[None, :], caches=self.caches) @@ -482,16 +512,20 @@ def __init__(self, checkpoint: str, context: int, device: torch.device): self.logits = self.model(self.input_token[None, :], caches=self.caches)[0] @torch.inference_mode() - def generate(self, - prompt_tokens: list[int], - stop_tokens: list[int] | None = None, - temperature: float = 1.0, - max_tokens: int = 0, - return_logprobs: bool = False): + def generate( + self, + prompt_tokens: list[int], + stop_tokens: list[int] | None = None, + temperature: float = 1.0, + max_tokens: int = 0, + return_logprobs: bool = False, + ): stop_tokens = stop_tokens or [] for cache in self.caches: cache.reset() - prompt_tokens = torch.as_tensor(prompt_tokens, dtype=torch.int32, device=self.device) + prompt_tokens = torch.as_tensor( + prompt_tokens, dtype=torch.int32, device=self.device + ) self.model(prompt_tokens[None, :-1], self.caches) predicted_token = prompt_tokens[-1] num_generated_tokens = 0 diff --git a/gpt_oss/triton/moe.py b/gpt_oss/triton/moe.py index 925dbd5..de32c72 100644 --- a/gpt_oss/triton/moe.py +++ b/gpt_oss/triton/moe.py @@ -1,16 +1,16 @@ import torch -from torch.profiler import record_function - import triton_kernels import triton_kernels.swiglu -from triton_kernels.numerics_details.mxfp import downcast_to_mxfp -from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation -from triton_kernels.matmul_ogs import matmul_ogs +from torch.profiler import record_function +from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation, + PrecisionConfig, matmul_ogs) from triton_kernels.numerics import InFlexData +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp from triton_kernels.routing import routing -from triton_kernels.tensor import convert_layout -from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout -from triton_kernels.tensor import wrap_torch_tensor, FP4 +from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor +from triton_kernels.tensor_details.layout import (HopperMXScaleLayout, + HopperMXValueLayout, + StridedLayout) def quantize_mx4(w): @@ -31,7 +31,22 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True return out_glu * (x_linear + 1) -def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, 
fused_act=True, interleaved=True): +def moe( + x, + wg, + w1, + w1_mx, + w2, + w2_mx, + bg, + b1, + b2, + experts_per_token=4, + num_experts=128, + swiglu_limit=7.0, + fused_act=True, + interleaved=True, +): if x.numel() == 0: return x @@ -42,19 +57,43 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert with record_function("wg"): logits = matmul_ogs(x, wg, bg, precision_config=pcg) with record_function("routing"): - rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1) + rdata, gather_indx, scatter_indx = routing( + logits, experts_per_token, simulated_ep=1 + ) if fused_act: assert interleaved, "Fused activation requires interleaved weights" with record_function("w1+swiglu"): - act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.702, swiglu_limit), 2) - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (1.702, swiglu_limit), + 2, + ) + x = matmul_ogs( + x, + w1, + b1, + rdata, + gather_indx=gather_indx, + precision_config=pc1, + fused_activation=act, + ) else: with record_function("w1"): - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1) + x = matmul_ogs( + x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1 + ) with record_function("swiglu"): x = swiglu(x, limit=swiglu_limit, interleaved=interleaved) with record_function("w2"): - x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2, gammas=rdata.gate_scal) + x = matmul_ogs( + x, + w2, + b2, + rdata, + scatter_indx=scatter_indx, + precision_config=pc2, + gammas=rdata.gate_scal, + ) return x diff --git a/gpt_oss/vllm/token_generator.py b/gpt_oss/vllm/token_generator.py index 000f322..8b65a76 100644 --- a/gpt_oss/vllm/token_generator.py +++ b/gpt_oss/vllm/token_generator.py @@ -1,4 +1,4 @@ -from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt +from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt class TokenGenerator: @@ -10,20 +10,24 @@ def __init__(self, model_path: str, tensor_parallel_size: int = 1): self.engine = LLMEngine.from_engine_args(args) self.request_id = 0 - def generate(self, - prompt_tokens: list[int], - stop_tokens: list[int] | None = None, - temperature: float = 1.0, - max_tokens: int = 0, - return_logprobs: bool = False): + def generate( + self, + prompt_tokens: list[int], + stop_tokens: list[int] | None = None, + temperature: float = 1.0, + max_tokens: int = 0, + return_logprobs: bool = False, + ): if max_tokens == 0: max_tokens = None request_id = str(self.request_id) self.request_id += 1 - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_tokens, - stop_token_ids=stop_tokens, - logprobs=0 if return_logprobs else None) + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + stop_token_ids=stop_tokens, + logprobs=0 if return_logprobs else None, + ) prompt = TokensPrompt(prompt_token_ids=prompt_tokens) self.engine.add_request(request_id, prompt, sampling_params) last_token_id = [] @@ -32,8 +36,12 @@ def generate(self, output = step_outputs[0].outputs[0] token_ids = output.token_ids logprobs_list = output.logprobs if hasattr(output, "logprobs") else None - new_token_ids = token_ids[len(last_token_id):] - new_logprobs = logprobs_list[len(last_token_id):] if logprobs_list is not None else [None] * 
len(new_token_ids) + new_token_ids = token_ids[len(last_token_id) :] + new_logprobs = ( + logprobs_list[len(last_token_id) :] + if logprobs_list is not None + else [None] * len(new_token_ids) + ) for token_id, logprobs in zip(new_token_ids, new_logprobs): last_token_id.append(token_id) if return_logprobs: diff --git a/tests/conftest.py b/tests/conftest.py index 4c008a3..47f4bfb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,15 @@ import os import sys +from typing import Any, Generator +from unittest.mock import MagicMock, Mock + import pytest -from typing import Generator, Any -from unittest.mock import Mock, MagicMock from fastapi.testclient import TestClient -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from openai_harmony import HarmonyEncodingName, load_harmony_encoding -from openai_harmony import ( - HarmonyEncodingName, - load_harmony_encoding, -) from gpt_oss.responses_api.api_server import create_api_server @@ -22,24 +21,25 @@ def harmony_encoding(): @pytest.fixture def mock_infer_token(harmony_encoding): fake_tokens = harmony_encoding.encode( - "<|channel|>final<|message|>Test response<|return|>", - allowed_special="all" + "<|channel|>final<|message|>Test response<|return|>", allowed_special="all" ) token_queue = fake_tokens.copy() - - def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int: + + def _mock_infer( + tokens: list[int], temperature: float = 0.0, new_request: bool = False + ) -> int: nonlocal token_queue if len(token_queue) == 0: token_queue = fake_tokens.copy() return token_queue.pop(0) + return _mock_infer @pytest.fixture def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]: app = create_api_server( - infer_next_token=mock_infer_token, - encoding=harmony_encoding + infer_next_token=mock_infer_token, encoding=harmony_encoding ) with TestClient(app) as client: yield client @@ -53,7 +53,7 @@ def sample_request_data(): "stream": False, "reasoning_effort": "low", "temperature": 0.7, - "tools": [] + "tools": [], } @@ -72,23 +72,23 @@ def mock_python_tool(): mock.execute.return_value = { "output": "print('Hello')", "error": None, - "exit_code": 0 + "exit_code": 0, } return mock @pytest.fixture(autouse=True) def reset_test_environment(): - test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH'] + test_env_vars = ["OPENAI_API_KEY", "GPT_OSS_MODEL_PATH"] original_values = {} - + for var in test_env_vars: if var in os.environ: original_values[var] = os.environ[var] del os.environ[var] - + yield - + for var, value in original_values.items(): os.environ[var] = value @@ -96,23 +96,23 @@ def reset_test_environment(): @pytest.fixture def performance_timer(): import time - + class Timer: def __init__(self): self.start_time = None self.end_time = None - + def start(self): self.start_time = time.time() - + def stop(self): self.end_time = time.time() return self.elapsed - + @property def elapsed(self): if self.start_time and self.end_time: return self.end_time - self.start_time return None - - return Timer() \ No newline at end of file + + return Timer() diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py index 7fd354b..fb27eba 100644 --- a/tests/test_api_endpoints.py +++ b/tests/test_api_endpoints.py @@ -1,12 +1,13 @@ -import pytest -import json import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from 
fastapi import status -from unittest.mock import patch, MagicMock, AsyncMock class TestResponsesEndpoint: - + def test_basic_response_creation(self, api_client, sample_request_data): response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK @@ -14,7 +15,7 @@ def test_basic_response_creation(self, api_client, sample_request_data): assert "id" in data assert data["object"] == "response" assert data["model"] == sample_request_data["model"] - + def test_response_with_high_reasoning(self, api_client, sample_request_data): sample_request_data["reasoning_effort"] = "high" response = api_client.post("/v1/responses", json=sample_request_data) @@ -22,7 +23,7 @@ def test_response_with_high_reasoning(self, api_client, sample_request_data): data = response.json() assert "id" in data assert data["status"] == "completed" - + def test_response_with_medium_reasoning(self, api_client, sample_request_data): sample_request_data["reasoning_effort"] = "medium" response = api_client.post("/v1/responses", json=sample_request_data) @@ -30,27 +31,23 @@ def test_response_with_medium_reasoning(self, api_client, sample_request_data): data = response.json() assert "id" in data assert data["status"] == "completed" - + def test_response_with_invalid_model(self, api_client, sample_request_data): sample_request_data["model"] = "invalid-model" response = api_client.post("/v1/responses", json=sample_request_data) # Should still accept but might handle differently assert response.status_code == status.HTTP_200_OK - + def test_response_with_empty_input(self, api_client, sample_request_data): sample_request_data["input"] = "" response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_response_with_tools(self, api_client, sample_request_data): - sample_request_data["tools"] = [ - { - "type": "browser_search" - } - ] + sample_request_data["tools"] = [{"type": "browser_search"}] response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_response_with_custom_temperature(self, api_client, sample_request_data): for temp in [0.0, 0.5, 1.0, 1.5, 2.0]: sample_request_data["temperature"] = temp @@ -58,10 +55,12 @@ def test_response_with_custom_temperature(self, api_client, sample_request_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert "usage" in data - + def test_streaming_response(self, api_client, sample_request_data): sample_request_data["stream"] = True - with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response: + with api_client.stream( + "POST", "/v1/responses", json=sample_request_data + ) as response: assert response.status_code == status.HTTP_200_OK # Verify we get SSE events for line in response.iter_lines(): @@ -73,63 +72,66 @@ def test_streaming_response(self, api_client, sample_request_data): class TestResponsesWithSession: - + def test_response_with_session_id(self, api_client, sample_request_data): session_id = "test-session-123" sample_request_data["session_id"] = session_id - + # First request response1 = api_client.post("/v1/responses", json=sample_request_data) assert response1.status_code == status.HTTP_200_OK data1 = response1.json() - + # Second request with same session sample_request_data["input"] = "Follow up question" response2 = api_client.post("/v1/responses", json=sample_request_data) assert response2.status_code == status.HTTP_200_OK 
data2 = response2.json() - + # Should have different response IDs assert data1["id"] != data2["id"] - + def test_response_continuation(self, api_client, sample_request_data): # Create initial response response1 = api_client.post("/v1/responses", json=sample_request_data) assert response1.status_code == status.HTTP_200_OK data1 = response1.json() response_id = data1["id"] - + # Continue the response continuation_request = { "model": sample_request_data["model"], "response_id": response_id, - "input": "Continue the previous thought" + "input": "Continue the previous thought", } response2 = api_client.post("/v1/responses", json=continuation_request) assert response2.status_code == status.HTTP_200_OK class TestErrorHandling: - + def test_missing_required_fields(self, api_client): # Model field has default, so test with empty JSON response = api_client.post("/v1/responses", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - + def test_invalid_reasoning_effort(self, api_client, sample_request_data): sample_request_data["reasoning_effort"] = "invalid" response = api_client.post("/v1/responses", json=sample_request_data) # May handle gracefully or return error - assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY] - + assert response.status_code in [ + status.HTTP_200_OK, + status.HTTP_422_UNPROCESSABLE_ENTITY, + ] + def test_malformed_json(self, api_client): response = api_client.post( "/v1/responses", data="not json", - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - + def test_extremely_long_input(self, api_client, sample_request_data): # Test with very long input sample_request_data["input"] = "x" * 100000 @@ -138,55 +140,51 @@ def test_extremely_long_input(self, api_client, sample_request_data): class TestToolIntegration: - + def test_browser_search_tool(self, api_client, sample_request_data): - sample_request_data["tools"] = [ - { - "type": "browser_search" - } - ] + sample_request_data["tools"] = [{"type": "browser_search"}] response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_function_tool_integration(self, api_client, sample_request_data): sample_request_data["tools"] = [ { "type": "function", "name": "test_function", "parameters": {"type": "object", "properties": {}}, - "description": "Test function" + "description": "Test function", } ] response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK - + def test_multiple_tools(self, api_client, sample_request_data): sample_request_data["tools"] = [ - { - "type": "browser_search" - }, + {"type": "browser_search"}, { "type": "function", "name": "test_function", "parameters": {"type": "object", "properties": {}}, - "description": "Test function" - } + "description": "Test function", + }, ] response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK class TestPerformance: - - def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer): + + def test_response_time_under_threshold( + self, api_client, sample_request_data, performance_timer + ): performance_timer.start() response = api_client.post("/v1/responses", json=sample_request_data) elapsed = performance_timer.stop() - + assert response.status_code == status.HTTP_200_OK # Response 
should be reasonably fast for mock inference assert elapsed < 5.0 # 5 seconds threshold - + def test_multiple_sequential_requests(self, api_client, sample_request_data): # Test multiple requests work correctly for i in range(3): @@ -197,12 +195,12 @@ def test_multiple_sequential_requests(self, api_client, sample_request_data): class TestUsageTracking: - + def test_usage_object_structure(self, api_client, sample_request_data): response = api_client.post("/v1/responses", json=sample_request_data) assert response.status_code == status.HTTP_200_OK data = response.json() - + assert "usage" in data usage = data["usage"] assert "input_tokens" in usage @@ -210,21 +208,21 @@ def test_usage_object_structure(self, api_client, sample_request_data): assert "total_tokens" in usage # reasoning_tokens may not always be present # assert "reasoning_tokens" in usage - + # Basic validation assert usage["input_tokens"] >= 0 assert usage["output_tokens"] >= 0 assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"] - + def test_usage_increases_with_longer_input(self, api_client, sample_request_data): # Short input response1 = api_client.post("/v1/responses", json=sample_request_data) usage1 = response1.json()["usage"] - + # Longer input sample_request_data["input"] = sample_request_data["input"] * 10 response2 = api_client.post("/v1/responses", json=sample_request_data) usage2 = response2.json()["usage"] - + # Longer input should use more tokens - assert usage2["input_tokens"] > usage1["input_tokens"] \ No newline at end of file + assert usage2["input_tokens"] > usage1["input_tokens"] diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py index c3caec3..33de68f 100644 --- a/tests/test_responses_api.py +++ b/tests/test_responses_api.py @@ -2,10 +2,7 @@ import pytest from fastapi.testclient import TestClient -from openai_harmony import ( - HarmonyEncodingName, - load_harmony_encoding, -) +from openai_harmony import HarmonyEncodingName, load_harmony_encoding from gpt_oss.responses_api.api_server import create_api_server
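Editorial aside, not part of the patch: the test changes above all lean on the conftest.py fixtures, which replay a canned Harmony completion instead of running a model. A condensed, self-contained version of that wiring is sketched below; the `HARMONY_GPT_OSS` encoding name, the model string, and the request body are assumptions, not values taken from the diff.

```python
# Minimal sketch: drive the Responses API server with a scripted token stream.
from fastapi.testclient import TestClient
from openai_harmony import HarmonyEncodingName, load_harmony_encoding

from gpt_oss.responses_api.api_server import create_api_server

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
script = list(encoding.encode(
    "<|channel|>final<|message|>Test response<|return|>", allowed_special="all"
))
queue = list(script)

def infer_next_token(tokens, temperature=0.0, new_request=False) -> int:
    # Ignore the prompt and replay the canned completion, looping if exhausted.
    global queue
    if not queue:
        queue = list(script)
    return queue.pop(0)

client = TestClient(create_api_server(infer_next_token=infer_next_token, encoding=encoding))
resp = client.post("/v1/responses", json={"model": "gpt-oss-120b", "input": "hi"})
print(resp.status_code, resp.json().get("status"))
```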