Skip to content

Commit 19f931b

Browse files
authored
NVIDIA NIM SUT inherits from OpenAI SUT (#1502)
* first draft * nim sut inherits from openai * oops, delete from wrong branch
1 parent e735956 commit 19f931b

File tree

2 files changed

+20
-94
lines changed

2 files changed

+20
-94
lines changed

src/modelgauge/suts/nvidia_nim_api_client.py

Lines changed: 16 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Optional
22

3-
from openai import OpenAI
4-
from openai import APITimeoutError, ConflictError, InternalServerError, RateLimitError
5-
from openai.types.chat import ChatCompletion
6-
from pydantic import BaseModel
7-
8-
from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
9-
from modelgauge.retry_decorator import retry
103
from modelgauge.secret_values import (
114
InjectSecret,
125
RequiredSecret,
136
SecretDescription,
147
)
15-
from modelgauge.sut import PromptResponseSUT, SUTResponse
8+
from modelgauge.suts.openai_client import OpenAIChat, OpenAIChatRequest
169
from modelgauge.model_options import ModelOptions
1710
from modelgauge.sut_capabilities import (
1811
AcceptsChatPrompt,
@@ -21,16 +14,8 @@
2114
from modelgauge.sut_decorator import modelgauge_sut
2215
from modelgauge.sut_registry import SUTS
2316

24-
_SYSTEM_ROLE = "system"
25-
_USER_ROLE = "user"
26-
_ASSISTANT_ROLE = "assistant"
27-
_TOOL_ROLE = "tool_call_id"
2817

29-
_ROLE_MAP = {
30-
ChatRole.user: _USER_ROLE,
31-
ChatRole.sut: _ASSISTANT_ROLE,
32-
ChatRole.system: _SYSTEM_ROLE,
33-
}
18+
BASE_URL = "https://integrate.api.nvidia.com/v1"
3419

3520

3621
class NvidiaNIMApiKey(RequiredSecret):
@@ -43,30 +28,10 @@ def description(cls) -> SecretDescription:
4328
)
4429

4530

46-
class OpenAIChatMessage(BaseModel):
47-
content: str
48-
role: str
49-
name: Optional[str] = None
50-
tool_calls: Optional[List[Dict]] = None
51-
tool_call_id: Optional[str] = None
52-
53-
54-
class OpenAIChatRequest(BaseModel):
55-
messages: List[OpenAIChatMessage]
56-
model: str
57-
frequency_penalty: Optional[float] = None
58-
logit_bias: Optional[bool] = None
59-
max_tokens: Optional[int] = 256
60-
presence_penalty: Optional[float] = None
61-
response_format: Optional[Dict] = None
62-
seed: Optional[int] = None
63-
stop: Optional[Union[str, List[str]]] = None
64-
stream: Optional[bool] = None
65-
temperature: Optional[float] = 1.0
66-
top_p: Optional[float] = None
67-
tools: Optional[List] = None
68-
tool_choice: Optional[Union[str, Dict]] = None
69-
user: Optional[str] = None
31+
class NIMOpenAIChatRequest(OpenAIChatRequest):
32+
max_tokens: Optional[int] = (
33+
256 # NVIDIA NIM uses the deprecated "max_tokens" param name instead of "max_completion_tokens"
34+
)
7035

7136

7237
@modelgauge_sut(
@@ -75,58 +40,23 @@ class OpenAIChatRequest(BaseModel):
7540
AcceptsChatPrompt,
7641
]
7742
)
78-
class NvidiaNIMApiClient(PromptResponseSUT):
43+
class NvidiaNIMApiClient(OpenAIChat):
7944
"""
8045
Documented at https://docs.api.nvidia.com/
8146
"""
8247

8348
def __init__(self, uid: str, model: str, api_key: NvidiaNIMApiKey):
84-
super().__init__(uid)
85-
self.model = model
86-
self.client: Optional[OpenAI] = None
87-
self.api_key = api_key.value
88-
89-
def _load_client(self) -> OpenAI:
90-
return OpenAI(api_key=self.api_key, base_url="https://integrate.api.nvidia.com/v1")
91-
92-
def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> OpenAIChatRequest:
93-
messages = [OpenAIChatMessage(content=prompt.text, role=_USER_ROLE)]
94-
return self._translate_request(messages, options)
95-
96-
def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> OpenAIChatRequest:
97-
messages = []
98-
for message in prompt.messages:
99-
messages.append(OpenAIChatMessage(content=message.text, role=_ROLE_MAP[message.role]))
100-
return self._translate_request(messages, options)
101-
102-
def _translate_request(self, messages: List[OpenAIChatMessage], options: ModelOptions):
103-
optional_kwargs: Dict[str, Any] = {}
104-
return OpenAIChatRequest(
105-
messages=messages,
106-
model=self.model,
107-
frequency_penalty=options.frequency_penalty,
49+
super().__init__(uid, model, api_key=api_key, base_url=BASE_URL)
50+
51+
def _translate_request(self, messages, options: ModelOptions) -> NIMOpenAIChatRequest:
52+
request = super()._translate_request(messages, options)
53+
request_json = request.model_dump(exclude_none=True)
54+
del request_json["max_completion_tokens"] # NIM API doesn't allow extra inputs
55+
return NIMOpenAIChatRequest(
10856
max_tokens=options.max_tokens,
109-
presence_penalty=options.presence_penalty,
110-
stop=options.stop_sequences,
111-
top_p=options.top_p,
112-
**optional_kwargs,
57+
**request_json,
11358
)
11459

115-
@retry(transient_exceptions=[APITimeoutError, ConflictError, InternalServerError, RateLimitError])
116-
def evaluate(self, request: OpenAIChatRequest) -> ChatCompletion:
117-
if self.client is None:
118-
# Handle lazy init.
119-
self.client = self._load_client()
120-
request_dict = request.model_dump(exclude_none=True)
121-
return self.client.chat.completions.create(**request_dict)
122-
123-
def translate_response(self, request: OpenAIChatRequest, response: ChatCompletion) -> SUTResponse:
124-
assert len(response.choices) == 1, f"Expected a single response message, got {len(response.choices)}."
125-
text = response.choices[0].message.content
126-
if text is None:
127-
text = ""
128-
return SUTResponse(text=text)
129-
13060

13161
SUTS.register(
13262
NvidiaNIMApiClient,

tests/modelgauge_tests/sut_tests/test_nvidia_nim_api_client.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
from modelgauge.suts.nvidia_nim_api_client import (
2-
NvidiaNIMApiKey,
3-
NvidiaNIMApiClient,
4-
OpenAIChatMessage,
5-
OpenAIChatRequest,
6-
)
71
from openai.types.chat import ChatCompletion
82

93
from modelgauge.prompt import TextPrompt
104
from modelgauge.sut import SUTResponse
5+
from modelgauge.suts.nvidia_nim_api_client import NIMOpenAIChatRequest, NvidiaNIMApiKey, NvidiaNIMApiClient
6+
from modelgauge.suts.openai_client import OpenAIChatMessage
117
from modelgauge.model_options import ModelOptions
128

139

@@ -19,7 +15,7 @@ def test_openai_chat_translate_request():
1915
client = _make_client()
2016
prompt = TextPrompt(text="some-text")
2117
request = client.translate_text_prompt(prompt, ModelOptions(max_tokens=100))
22-
assert request == OpenAIChatRequest(
18+
assert request == NIMOpenAIChatRequest(
2319
model="some-model",
2420
messages=[OpenAIChatMessage(content="some-text", role="user")],
2521
max_tokens=100,
@@ -29,7 +25,7 @@ def test_openai_chat_translate_request():
2925

3026
def test_openai_chat_translate_response():
3127
client = _make_client()
32-
request = OpenAIChatRequest(
28+
request = NIMOpenAIChatRequest(
3329
model="some-model",
3430
messages=[],
3531
)

0 commit comments

Comments
 (0)