Skip to content

Commit 03a48d3

Browse files
committed
feat: accept kwargs in list_model(), fix: max_tokens parameter
1 parent be953d9 commit 03a48d3

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,16 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage:
181181
total_tokens=usage["totalTokens"],
182182
)
183183

184-
def list_models(self):
184+
def list_models(self, **kwargs):
185185
# Initialize Bedrock client (not runtime)
186186

187187
try:
188-
response = self._client.list_foundation_models()
188+
response = self._client.list_foundation_models(**kwargs)
189189
models = response.get("modelSummaries", [])
190190
for model in models:
191191
print(f"Model ID: {model['modelId']}")
192192
print(f" Name: {model['modelName']}")
193+
print(f" Model ARN: {model['modelArn']}")
193194
print(f" Provider: {model['providerName']}")
194195
print(f" Input: {model['inputModalities']}")
195196
print(f" Output: {model['outputModalities']}")
@@ -198,7 +199,7 @@ def list_models(self):
198199
except Exception as e:
199200
print(f"Error listing models: {e}")
200201

201-
def _validate_and_process_model_id(self, api_kwargs: Dict):
202+
def _validate_and_process_config_keys(self, api_kwargs: Dict):
202203
"""
203204
Validate and process the model ID in API kwargs.
204205
@@ -210,6 +211,10 @@ def _validate_and_process_model_id(self, api_kwargs: Dict):
210211
else:
211212
raise KeyError("The required key 'model' is missing in model_kwargs.")
212213

214+
# In .converse() `maxTokens`` is the key for maximum tokens limit
215+
if "max_tokens" in api_kwargs:
216+
api_kwargs["maxTokens"] = api_kwargs.pop("max_tokens")
217+
213218
return api_kwargs
214219

215220
def _separate_parameters(self, api_kwargs: Dict) -> tuple:
@@ -252,7 +257,7 @@ def convert_inputs_to_api_kwargs(
252257
api_kwargs = model_kwargs.copy()
253258
if model_type == ModelType.LLM:
254259
# Validate and process model ID
255-
api_kwargs = self._validate_and_process_model_id(api_kwargs)
260+
api_kwargs = self._validate_and_process_config_keys(api_kwargs)
256261

257262
# Separate inference config and additional model request fields
258263
api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs)

0 commit comments

Comments
 (0)