1- from typing import Any , Dict , List , Optional , Union
1+ from typing import Optional
22
3- from openai import OpenAI
4- from openai import APITimeoutError , ConflictError , InternalServerError , RateLimitError
5- from openai .types .chat import ChatCompletion
6- from pydantic import BaseModel
7-
8- from modelgauge .prompt import ChatPrompt , ChatRole , TextPrompt
9- from modelgauge .retry_decorator import retry
103from modelgauge .secret_values import (
114 InjectSecret ,
125 RequiredSecret ,
136 SecretDescription ,
147)
15- from modelgauge .sut import PromptResponseSUT , SUTResponse
8+ from modelgauge .suts . openai_client import OpenAIChat , OpenAIChatRequest
169from modelgauge .model_options import ModelOptions
1710from modelgauge .sut_capabilities import (
1811 AcceptsChatPrompt ,
2114from modelgauge .sut_decorator import modelgauge_sut
2215from modelgauge .sut_registry import SUTS
2316
24- _SYSTEM_ROLE = "system"
25- _USER_ROLE = "user"
26- _ASSISTANT_ROLE = "assistant"
27- _TOOL_ROLE = "tool_call_id"
2817
29- _ROLE_MAP = {
30- ChatRole .user : _USER_ROLE ,
31- ChatRole .sut : _ASSISTANT_ROLE ,
32- ChatRole .system : _SYSTEM_ROLE ,
33- }
18+ BASE_URL = "https://integrate.api.nvidia.com/v1"
3419
3520
3621class NvidiaNIMApiKey (RequiredSecret ):
@@ -43,30 +28,10 @@ def description(cls) -> SecretDescription:
4328 )
4429
4530
46- class OpenAIChatMessage (BaseModel ):
47- content : str
48- role : str
49- name : Optional [str ] = None
50- tool_calls : Optional [List [Dict ]] = None
51- tool_call_id : Optional [str ] = None
52-
53-
54- class OpenAIChatRequest (BaseModel ):
55- messages : List [OpenAIChatMessage ]
56- model : str
57- frequency_penalty : Optional [float ] = None
58- logit_bias : Optional [bool ] = None
59- max_tokens : Optional [int ] = 256
60- presence_penalty : Optional [float ] = None
61- response_format : Optional [Dict ] = None
62- seed : Optional [int ] = None
63- stop : Optional [Union [str , List [str ]]] = None
64- stream : Optional [bool ] = None
65- temperature : Optional [float ] = 1.0
66- top_p : Optional [float ] = None
67- tools : Optional [List ] = None
68- tool_choice : Optional [Union [str , Dict ]] = None
69- user : Optional [str ] = None
31+ class NIMOpenAIChatRequest (OpenAIChatRequest ):
32+ max_tokens : Optional [int ] = (
33+ 256 # NVIDIA NIM uses the deprecated "max_tokens" param name instead of "max_completion_tokens"
34+ )
7035
7136
7237@modelgauge_sut (
@@ -75,58 +40,23 @@ class OpenAIChatRequest(BaseModel):
7540 AcceptsChatPrompt ,
7641 ]
7742)
78- class NvidiaNIMApiClient (PromptResponseSUT ):
43+ class NvidiaNIMApiClient (OpenAIChat ):
7944 """
8045 Documented at https://docs.api.nvidia.com/
8146 """
8247
8348 def __init__ (self , uid : str , model : str , api_key : NvidiaNIMApiKey ):
84- super ().__init__ (uid )
85- self .model = model
86- self .client : Optional [OpenAI ] = None
87- self .api_key = api_key .value
88-
89- def _load_client (self ) -> OpenAI :
90- return OpenAI (api_key = self .api_key , base_url = "https://integrate.api.nvidia.com/v1" )
91-
92- def translate_text_prompt (self , prompt : TextPrompt , options : ModelOptions ) -> OpenAIChatRequest :
93- messages = [OpenAIChatMessage (content = prompt .text , role = _USER_ROLE )]
94- return self ._translate_request (messages , options )
95-
96- def translate_chat_prompt (self , prompt : ChatPrompt , options : ModelOptions ) -> OpenAIChatRequest :
97- messages = []
98- for message in prompt .messages :
99- messages .append (OpenAIChatMessage (content = message .text , role = _ROLE_MAP [message .role ]))
100- return self ._translate_request (messages , options )
101-
102- def _translate_request (self , messages : List [OpenAIChatMessage ], options : ModelOptions ):
103- optional_kwargs : Dict [str , Any ] = {}
104- return OpenAIChatRequest (
105- messages = messages ,
106- model = self .model ,
107- frequency_penalty = options .frequency_penalty ,
49+ super ().__init__ (uid , model , api_key = api_key , base_url = BASE_URL )
50+
51+ def _translate_request (self , messages , options : ModelOptions ) -> NIMOpenAIChatRequest :
52+ request = super ()._translate_request (messages , options )
53+ request_json = request .model_dump (exclude_none = True )
54+ del request_json ["max_completion_tokens" ] # NIM API doesn't allow extra inputs
55+ return NIMOpenAIChatRequest (
10856 max_tokens = options .max_tokens ,
109- presence_penalty = options .presence_penalty ,
110- stop = options .stop_sequences ,
111- top_p = options .top_p ,
112- ** optional_kwargs ,
57+ ** request_json ,
11358 )
11459
115- @retry (transient_exceptions = [APITimeoutError , ConflictError , InternalServerError , RateLimitError ])
116- def evaluate (self , request : OpenAIChatRequest ) -> ChatCompletion :
117- if self .client is None :
118- # Handle lazy init.
119- self .client = self ._load_client ()
120- request_dict = request .model_dump (exclude_none = True )
121- return self .client .chat .completions .create (** request_dict )
122-
123- def translate_response (self , request : OpenAIChatRequest , response : ChatCompletion ) -> SUTResponse :
124- assert len (response .choices ) == 1 , f"Expected a single response message, got { len (response .choices )} ."
125- text = response .choices [0 ].message .content
126- if text is None :
127- text = ""
128- return SUTResponse (text = text )
129-
13060
13161SUTS .register (
13262 NvidiaNIMApiClient ,
0 commit comments