Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 16 additions & 86 deletions src/modelgauge/suts/nvidia_nim_api_client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from typing import Any, Dict, List, Optional, Union
from typing import Optional

from openai import OpenAI
from openai import APITimeoutError, ConflictError, InternalServerError, RateLimitError
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import (
InjectSecret,
RequiredSecret,
SecretDescription,
)
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.suts.openai_client import OpenAIChat, OpenAIChatRequest
from modelgauge.model_options import ModelOptions
from modelgauge.sut_capabilities import (
AcceptsChatPrompt,
Expand All @@ -21,16 +14,8 @@
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS

_SYSTEM_ROLE = "system"
_USER_ROLE = "user"
_ASSISTANT_ROLE = "assistant"
_TOOL_ROLE = "tool_call_id"

_ROLE_MAP = {
ChatRole.user: _USER_ROLE,
ChatRole.sut: _ASSISTANT_ROLE,
ChatRole.system: _SYSTEM_ROLE,
}
BASE_URL = "https://integrate.api.nvidia.com/v1"


class NvidiaNIMApiKey(RequiredSecret):
Expand All @@ -43,30 +28,10 @@ def description(cls) -> SecretDescription:
)


class OpenAIChatMessage(BaseModel):
    """One message in an OpenAI-style chat request.

    Mirrors the message object of the OpenAI Chat Completions API so it can be
    serialized directly into the request payload.
    """

    # The text content of the message.
    content: str
    # Speaker role, e.g. "user", "assistant", "system" (see the _ROLE_MAP constants).
    role: str
    # Optional participant name to disambiguate speakers with the same role.
    name: Optional[str] = None
    # Tool/function calls emitted by the assistant, if any.
    tool_calls: Optional[List[Dict]] = None
    # ID linking a tool-role message back to the tool call it answers.
    tool_call_id: Optional[str] = None


class OpenAIChatRequest(BaseModel):
    """Request body for the OpenAI Chat Completions endpoint.

    Fields mirror the API's request parameters; fields left as ``None`` are
    excluded from the serialized payload (callers dump with
    ``exclude_none=True``), so only explicitly-set options are sent.
    """

    # The conversation so far, oldest message first.
    messages: List[OpenAIChatMessage]
    # Model identifier to route the request to.
    model: str
    frequency_penalty: Optional[float] = None
    # Fix: the API defines logit_bias as a map from token ID (JSON string key)
    # to a bias value in [-100, 100] — it is not a boolean flag.
    logit_bias: Optional[Dict[str, int]] = None
    # Deliberate non-None default: cap generation length unless overridden.
    max_tokens: Optional[int] = 256
    presence_penalty: Optional[float] = None
    response_format: Optional[Dict] = None
    seed: Optional[int] = None
    # A single stop string or a list of up to four stop sequences.
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = None
    # Deliberate non-None default matching the API's documented default.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = None
    tools: Optional[List] = None
    # Either a string mode ("none"/"auto"/"required") or a specific-tool dict.
    tool_choice: Optional[Union[str, Dict]] = None
    # End-user identifier for abuse monitoring.
    user: Optional[str] = None
class NIMOpenAIChatRequest(OpenAIChatRequest):
    """Chat request variant for the NVIDIA NIM API.

    Identical to the OpenAI request shape except for the token-limit
    parameter name, which NIM has not migrated off of.
    """

    max_tokens: Optional[int] = (
        256  # NVIDIA NIM uses the deprecated "max_tokens" param name instead of "max_completion_tokens"
    )


@modelgauge_sut(
Expand All @@ -75,58 +40,23 @@ class OpenAIChatRequest(BaseModel):
AcceptsChatPrompt,
]
)
class NvidiaNIMApiClient(PromptResponseSUT):
class NvidiaNIMApiClient(OpenAIChat):
"""
Documented at https://docs.api.nvidia.com/
"""

def __init__(self, uid: str, model: str, api_key: NvidiaNIMApiKey):
super().__init__(uid)
self.model = model
self.client: Optional[OpenAI] = None
self.api_key = api_key.value

def _load_client(self) -> OpenAI:
return OpenAI(api_key=self.api_key, base_url="https://integrate.api.nvidia.com/v1")

def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> OpenAIChatRequest:
messages = [OpenAIChatMessage(content=prompt.text, role=_USER_ROLE)]
return self._translate_request(messages, options)

def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> OpenAIChatRequest:
messages = []
for message in prompt.messages:
messages.append(OpenAIChatMessage(content=message.text, role=_ROLE_MAP[message.role]))
return self._translate_request(messages, options)

def _translate_request(self, messages: List[OpenAIChatMessage], options: ModelOptions):
optional_kwargs: Dict[str, Any] = {}
return OpenAIChatRequest(
messages=messages,
model=self.model,
frequency_penalty=options.frequency_penalty,
super().__init__(uid, model, api_key=api_key, base_url=BASE_URL)

def _translate_request(self, messages, options: ModelOptions) -> NIMOpenAIChatRequest:
request = super()._translate_request(messages, options)
request_json = request.model_dump(exclude_none=True)
del request_json["max_completion_tokens"] # NIM API doesn't allow extra inputs
return NIMOpenAIChatRequest(
max_tokens=options.max_tokens,
presence_penalty=options.presence_penalty,
stop=options.stop_sequences,
top_p=options.top_p,
**optional_kwargs,
**request_json,
)

@retry(transient_exceptions=[APITimeoutError, ConflictError, InternalServerError, RateLimitError])
def evaluate(self, request: OpenAIChatRequest) -> ChatCompletion:
if self.client is None:
# Handle lazy init.
self.client = self._load_client()
request_dict = request.model_dump(exclude_none=True)
return self.client.chat.completions.create(**request_dict)

def translate_response(self, request: OpenAIChatRequest, response: ChatCompletion) -> SUTResponse:
assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
text = response.choices[0].message.content
if text is None:
text = ""
return SUTResponse(text=text)


SUTS.register(
NvidiaNIMApiClient,
Expand Down
12 changes: 4 additions & 8 deletions tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from modelgauge.suts.nvidia_nim_api_client import (
NvidiaNIMApiKey,
NvidiaNIMApiClient,
OpenAIChatMessage,
OpenAIChatRequest,
)
from openai.types.chat import ChatCompletion

from modelgauge.prompt import TextPrompt
from modelgauge.sut import SUTResponse
from modelgauge.suts.nvidia_nim_api_client import NIMOpenAIChatRequest, NvidiaNIMApiKey, NvidiaNIMApiClient
from modelgauge.suts.openai_client import OpenAIChatMessage
from modelgauge.model_options import ModelOptions


Expand All @@ -19,7 +15,7 @@ def test_openai_chat_translate_request():
client = _make_client()
prompt = TextPrompt(text="some-text")
request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100))
assert request == OpenAIChatRequest(
assert request == NIMOpenAIChatRequest(
model="some-model",
messages=[OpenAIChatMessage(content="some-text", role="user")],
max_tokens=100,
Expand All @@ -29,7 +25,7 @@ def test_openai_chat_translate_request():

def test_openai_chat_translate_response():
client = _make_client()
request = OpenAIChatRequest(
request = NIMOpenAIChatRequest(
model="some-model",
messages=[],
)
Expand Down
Loading