# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
+from collections.abc import AsyncGenerator

from fireworks.client import Fireworks
-from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    Inference,
    LogProbConfig,
    Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
    convert_message_to_openai_dict,
    get_sampling_options,
-    prepare_openai_completion_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,

logger = get_logger(name=__name__, category="inference::fireworks")


-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
    def __init__(self, config: FireworksImplConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
        self.config = config
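Listing OpenAIMixin first matters: Python resolves attributes left to right along the MRO, so the mixin's openai_* implementations take precedence over anything defined by the classes to its right. A minimal illustration of that ordering (hypothetical class names, not from this repo):

class Base:
    def openai_completion(self) -> str:
        return "base"

class Mixin:
    def openai_completion(self) -> str:
        return "mixin"

class Adapter(Mixin, Base):  # Mixin first, mirroring OpenAIMixin above
    pass

assert Adapter().openai_completion() == "mixin"  # MRO picks the mixin's method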
@@ -79,7 +70,7 @@ async def initialize(self) -> None:
    async def shutdown(self) -> None:
        pass

-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
        if config_api_key:
            return config_api_key
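The rename from _get_api_key to get_api_key matches the hook name this diff's OpenAIMixin expects. The resolution order is unchanged: a key in the provider config wins, and only then does per-request provider data apply. A standalone sketch of that precedence (function name and error message are illustrative, not the adapter's exact code):

def resolve_api_key(config_key: str | None, request_key: str | None) -> str:
    # Static provider config takes precedence...
    if config_key:
        return config_key
    # ...then the key supplied via request provider data.
    if request_key:
        return request_key
    raise ValueError("no Fireworks API key available")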
@@ -91,15 +82,18 @@ def _get_api_key(self) -> str:
                )
            return provider_data.fireworks_api_key

-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
        return "https://api.fireworks.ai/inference/v1"

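With get_api_key and get_base_url exposed as public hooks, OpenAI client construction moves into the mixin, which is why the _get_openai_client helper disappears below. A plausible sketch of how the mixin can use these hooks (assumed shape, not the actual OpenAIMixin source):

from openai import AsyncOpenAI

class OpenAIMixinSketch:
    def get_api_key(self) -> str: ...   # supplied by the adapter
    def get_base_url(self) -> str: ...  # supplied by the adapter

    @property
    def client(self) -> AsyncOpenAI:    # property name is an assumption
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())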
    def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
+        fireworks_api_key = self.get_api_key()
        return Fireworks(api_key=fireworks_api_key)

-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
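This helper re-homes the BOS-stripping that previously lived inline in openai_completion (removed further down), so the behavior survives the move to the mixin. A quick standalone check of the same logic:

def strip_bos(prompt: str) -> str:
    bos = "<|begin_of_text|>"
    return prompt[len(bos):] if prompt.startswith(bos) else prompt

assert strip_bos("<|begin_of_text|>Hello") == "Hello"
assert strip_bos("Hello") == "Hello"  # prompts without BOS pass through unchanged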

    async def completion(
        self,
@@ -285,153 +279,3 @@ async def embeddings(

        embeddings = [data.embedding for data in response.data]
        return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-
-        return await self._get_openai_client().completions.create(**params)
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-
-        params = await prepare_openai_completion_params(
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
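With these hand-written entry points gone, including the special-case divert of Llama models through OpenAIChatCompletionToLlamaStackMixin, the openai_embeddings, openai_completion, and openai_chat_completion methods are presumably served by OpenAIMixin via the get_api_key/get_base_url hooks above. From a caller's perspective nothing changes; a hedged usage sketch (model id and message shape are illustrative):

response = await adapter.openai_chat_completion(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    messages=[{"role": "user", "content": "Hello"}],
)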