
Commit 783266f

slekkala1 authored and iamemilio committed
chore: Refactor fireworks to use OpenAIMixin (llamastack#3480)
# What does this PR do?

Refactor Fireworks to use OpenAIMixin

Closes llamastack#3391
Related to llamastack#3387

## Test Plan

```
(llama-stack) (base) swapna942@swapna942-mac llama-stack % FIREWORKS_API_KEY=**** ./scripts/integration-tests.sh --stack-config server:ci-tests --setup fireworks --subdirs inference --pattern openai
instantiating llama_stack_client
Port 8321 is already in use, assuming server is already running...
llama_stack_client instantiated in 0.031s
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 2%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 4%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 6%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 8%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 10%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 12%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 14%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 17%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 19%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 21%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:sanity] PASSED [ 23%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming_suffix[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:suffix] SKIPPED [ 25%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_streaming[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:sanity] PASSED [ 27%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-1] SKIPPED [ 29%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_guided_choice[txt=accounts/fireworks/models/llama-v3p1-8b-instruct] SKIPPED [ 31%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_01] PASSED [ 34%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 36%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 38%]
tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 40%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 42%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming_with_file[txt=accounts/fireworks/models/llama-v3p1-8b-instruct] SKIPPED [ 44%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 46%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 48%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 51%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 53%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 55%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 57%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 59%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 61%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 63%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 65%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-0] SKIPPED [ 68%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02] PASSED [ 70%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 72%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 74%]
tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 76%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 78%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_01] PASSED [ 80%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 82%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 85%]
tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 87%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 89%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02] PASSED [ 91%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 93%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 95%]
tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 97%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [100%]
========================================== slowest 10 durations ==========================================
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[openai_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[openai_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[openai_client-emb=nomic-ai/nomic-embed-text-v1.5]
30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5]
================= 36 passed, 11 skipped, 50 deselected, 4 warnings in 1429.05s (0:23:49) =================
+ exit_code=0
+ set +x
✅ All tests completed successfully
```
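For orientation before the diff: `OpenAIMixin` supplies the OpenAI-compatible `openai_completion` / `openai_chat_completion` / `openai_embeddings` implementations and asks the adapter only for credentials and an endpoint, which is why the Fireworks adapter below can drop ~150 lines and keep just `get_api_key()` and `get_base_url()`. Here is a minimal sketch of that contract; the real mixin lives in `llama_stack/providers/utils/inference/openai_mixin.py`, and the method bodies below are simplified assumptions, not its actual source:

```python
# Sketch only: bodies are illustrative assumptions about the OpenAIMixin
# contract, not the actual llama-stack implementation.
from openai import AsyncOpenAI


class OpenAIMixinSketch:
    """Adapters override the two hooks; the mixin owns the client and endpoints."""

    def get_api_key(self) -> str:  # hook: provider credentials
        raise NotImplementedError

    def get_base_url(self) -> str:  # hook: provider's OpenAI-compatible URL
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # Built from the two hooks, replacing each adapter's hand-rolled
        # _get_openai_client() (deleted from fireworks.py below).
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

    async def openai_completion(self, model: str, prompt: str, **params):
        # Generic pass-through to the /completions endpoint; per-provider
        # quirks (e.g. Fireworks BOS stripping) hook in before this call.
        return await self.client.completions.create(model=model, prompt=prompt, **params)
```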
1 parent c85fd2f commit 783266f

File tree

3 files changed: +22 −168

- llama_stack/providers/remote/inference/fireworks/fireworks.py
- tests/integration/inference/test_openai_embeddings.py
- tests/integration/suites.py

llama_stack/providers/remote/inference/fireworks/fireworks.py

Lines changed: 11 additions & 167 deletions
```diff
@@ -4,11 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
+from collections.abc import AsyncGenerator
 
 from fireworks.client import Fireworks
-from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -24,12 +22,6 @@
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -45,15 +37,14 @@
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +59,7 @@
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +70,7 @@ async def initialize(self) -> None:
     async def shutdown(self) -> None:
         pass
 
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
@@ -91,15 +82,18 @@ def _get_api_key(self) -> str:
         )
         return provider_data.fireworks_api_key
 
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
     def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
+        fireworks_api_key = self.get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
 
     async def completion(
         self,
@@ -285,153 +279,3 @@ async def embeddings(
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-
-        return await self._get_openai_client().completions.create(**params)
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-
-        params = await prepare_openai_completion_params(
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
```
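The one Fireworks-specific wrinkle the refactor keeps is BOS handling: as the new `_preprocess_prompt_for_fireworks` helper above shows, Fireworks prepends `<|begin_of_text|>` itself, so a templated prompt that already starts with BOS would otherwise get it twice. A standalone restatement of that rule, runnable as-is for illustration:

```python
# Restatement of the BOS-stripping behavior from the diff above: strip a
# leading <|begin_of_text|> because the Fireworks service adds its own.
BOS = "<|begin_of_text|>"


def preprocess_prompt_for_fireworks(prompt: str) -> str:
    """Remove BOS token as Fireworks automatically prepends it."""
    return prompt[len(BOS):] if prompt.startswith(BOS) else prompt


assert preprocess_prompt_for_fireworks(BOS + "Hello") == "Hello"
assert preprocess_prompt_for_fireworks("Hello") == "Hello"  # unchanged without BOS
```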

tests/integration/inference/test_openai_embeddings.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # service returns 400
+        "remote::fireworks",  # service returns 400 malformed input
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
 
@@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # param silently ignored, always returns floats
+        "remote::fireworks",  # param silently ignored, always returns list of floats
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
 
@@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
         input=input_texts,
         encoding_format="base64",
     )
-
     # Validate response structure
     assert response.object == "list"
     assert response.model == embedding_model_id
```

tests/integration/suites.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -108,6 +108,15 @@ class Setup(BaseModel):
             "embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
         },
     ),
+    "fireworks": Setup(
+        name="fireworks",
+        description="Fireworks provider with a text model",
+        defaults={
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+        },
+    ),
 }
```
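With this setup registered, the integration suite can target Fireworks by name; this is the exact invocation exercised in the test plan above:

```
FIREWORKS_API_KEY=**** ./scripts/integration-tests.sh \
  --stack-config server:ci-tests \
  --setup fireworks \
  --subdirs inference \
  --pattern openai
```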