Skip to content

Commit 307dbe1

Browse files
authored
feat: add fix to watsonx and note to litellm (#173)
* fix: watsonx client creation issue * fix: add litellm message
1 parent a4b6d27 commit 307dbe1

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

mellea/backends/litellm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __init__(
5353
):
5454
"""Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
5555
56+
Note: If getting `Unclosed client session`, set `export DISABLE_AIOHTTP_TRANSPORT=True` in your environment. See: https://github.com/BerriAI/litellm/issues/13251.
57+
5658
Args:
5759
model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
5860
formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter

mellea/backends/watsonx.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def __init__(
9090
if api_key is None:
9191
api_key = os.environ.get("WATSONX_API_KEY")
9292
if project_id is None:
93-
project_id = os.environ.get("WATSONX_PROJECT_ID")
93+
self._project_id = os.environ.get("WATSONX_PROJECT_ID")
9494

9595
self._creds = Credentials(url=base_url, api_key=api_key)
9696
_client = APIClient(credentials=self._creds)
9797
self._model_inference = ModelInference(
9898
model_id=self._get_watsonx_model_id(),
9999
api_client=_client,
100100
credentials=self._creds,
101-
project_id=project_id,
101+
project_id=self._project_id,
102102
params=self.model_options,
103103
**kwargs,
104104
)
@@ -135,7 +135,14 @@ def __init__(
135135
@property
136136
def _model(self) -> ModelInference:
137137
"""Watsonx's client gets tied to a specific event loop. Reset it here."""
138-
self._model_inference.set_api_client(APIClient(self._creds))
138+
_client = APIClient(credentials=self._creds)
139+
self._model_inference = ModelInference(
140+
model_id=self._get_watsonx_model_id(),
141+
api_client=_client,
142+
credentials=self._creds,
143+
project_id=self._project_id,
144+
params=self.model_options,
145+
)
139146
return self._model_inference
140147

141148
def _get_watsonx_model_id(self) -> str:

0 commit comments

Comments
 (0)