Skip to content

Commit cc1887f

Browse files
committed
feat: accept kwargs in list_model(), fix: max_tokens parameter
1 parent 59965c9 commit cc1887f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,16 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage:
221221
total_tokens=usage["totalTokens"],
222222
)
223223

224-
def list_models(self):
224+
def list_models(self, **kwargs):
225225
# Initialize Bedrock client (not runtime)
226226

227227
try:
228-
response = self._client.list_foundation_models()
228+
response = self._client.list_foundation_models(**kwargs)
229229
models = response.get("modelSummaries", [])
230230
for model in models:
231231
print(f"Model ID: {model['modelId']}")
232232
print(f" Name: {model['modelName']}")
233+
print(f" Model ARN: {model['modelArn']}")
233234
print(f" Provider: {model['providerName']}")
234235
print(f" Input: {model['inputModalities']}")
235236
print(f" Output: {model['outputModalities']}")
@@ -239,7 +240,7 @@ def list_models(self):
239240
except Exception as e:
240241
print(f"Error listing models: {e}")
241242

242-
def _validate_and_process_model_id(self, api_kwargs: Dict):
243+
def _validate_and_process_config_keys(self, api_kwargs: Dict):
243244
"""
244245
Validate and process the model ID in API kwargs.
245246
@@ -251,6 +252,10 @@ def _validate_and_process_model_id(self, api_kwargs: Dict):
251252
else:
252253
raise KeyError("The required key 'model' is missing in model_kwargs.")
253254

255+
# In .converse() `maxTokens`` is the key for maximum tokens limit
256+
if "max_tokens" in api_kwargs:
257+
api_kwargs["maxTokens"] = api_kwargs.pop("max_tokens")
258+
254259
return api_kwargs
255260

256261
def _separate_parameters(self, api_kwargs: Dict) -> tuple:
@@ -293,7 +298,7 @@ def convert_inputs_to_api_kwargs(
293298
api_kwargs = model_kwargs.copy()
294299
if model_type == ModelType.LLM:
295300
# Validate and process model ID
296-
api_kwargs = self._validate_and_process_model_id(api_kwargs)
301+
api_kwargs = self._validate_and_process_config_keys(api_kwargs)
297302

298303
# Separate inference config and additional model request fields
299304
api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs)

0 commit comments

Comments
 (0)