Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/fish_audio_sdk/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ def __init__(self, apikey: str, *, base_url: str = "https://api.fish.audio"):
def init_async_client(self):
self._async_client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
timeout=None,
)

def init_sync_client(self):
self._sync_client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
timeout=None,
)

Expand Down
10 changes: 8 additions & 2 deletions src/fish_audio_sdk/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def __init__(
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
)

def __enter__(self):
Expand Down Expand Up @@ -97,7 +100,10 @@ def __init__(
self._base_url = base_url
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
headers={
"Authorization": f"Bearer {self._apikey}",
"User-Agent": "fish-audio/python/legacy",
},
)

async def __aenter__(self):
Expand Down
10 changes: 5 additions & 5 deletions src/fishaudio/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def __init__(
)
self.base_url = base_url

def _get_headers(
def get_headers(
self, additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""Build headers including authentication."""
"""Build headers including authentication and user agent."""
headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": f"fish-audio/python/{__version__}",
Expand All @@ -77,7 +77,7 @@ def _prepare_request_kwargs(
) -> None:
"""Prepare request kwargs by merging headers, timeout, and query params."""
# Merge headers
headers = self._get_headers()
headers = self.get_headers()
if request_options and request_options.additional_headers:
headers.update(request_options.additional_headers)
kwargs["headers"] = {**headers, **kwargs.get("headers", {})}
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
self._client = httpx.Client(
base_url=base_url,
timeout=httpx.Timeout(timeout),
headers=self._get_headers(),
headers=self.get_headers(),
)

def request(
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(
self._client = httpx.AsyncClient(
base_url=base_url,
timeout=httpx.Timeout(timeout),
headers=self._get_headers(),
headers=self.get_headers(),
)

async def request(
Expand Down
7 changes: 2 additions & 5 deletions src/fishaudio/resources/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,7 @@ def text_generator():
with connect_ws(
"/v1/tts/live",
client=self._client.client,
headers={
"model": model,
"Authorization": f"Bearer {self._client.api_key}",
},
headers=self._client.get_headers({"model": model}),
**ws_kwargs,
) as ws:

Expand Down Expand Up @@ -630,7 +627,7 @@ async def text_generator():
async with aconnect_ws(
"/v1/tts/live",
client=self._client.client,
headers={"model": model, "Authorization": f"Bearer {self._client.api_key}"},
headers=self._client.get_headers({"model": model}),
**ws_kwargs,
) as ws:

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def test_init_with_env_var(self, mock_api_key):

def test_get_headers(self, mock_api_key):
wrapper = ClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers()
headers = wrapper.get_headers()
assert headers["Authorization"] == f"Bearer {mock_api_key}"
assert "User-Agent" in headers

def test_get_headers_with_additional(self, mock_api_key):
wrapper = ClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers({"X-Custom": "value"})
headers = wrapper.get_headers({"X-Custom": "value"})
assert headers["X-Custom"] == "value"
assert headers["Authorization"] == f"Bearer {mock_api_key}"

Expand All @@ -139,6 +139,6 @@ def test_init_without_api_key_raises(self):

def test_get_headers(self, mock_api_key):
wrapper = AsyncClientWrapper(api_key=mock_api_key)
headers = wrapper._get_headers()
headers = wrapper.get_headers()
assert headers["Authorization"] == f"Bearer {mock_api_key}"
assert "User-Agent" in headers
12 changes: 12 additions & 0 deletions tests/unit/test_tts_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def mock_client_wrapper(mock_api_key):
wrapper.api_key = mock_api_key
# Mock the underlying httpx.Client
wrapper._client = Mock()
# Mock get_headers to return a dict with the additional headers merged
wrapper.get_headers = lambda additional=None: {
"Authorization": f"Bearer {mock_api_key}",
"User-Agent": "fish-audio/python/test",
**(additional or {}),
}
return wrapper


Expand All @@ -26,6 +32,12 @@ def async_mock_client_wrapper(mock_api_key):
wrapper.api_key = mock_api_key
# Mock the underlying httpx.AsyncClient
wrapper._client = Mock()
# Mock get_headers to return a dict with the additional headers merged
wrapper.get_headers = lambda additional=None: {
"Authorization": f"Bearer {mock_api_key}",
"User-Agent": "fish-audio/python/test",
**(additional or {}),
}
return wrapper


Expand Down
Loading