Skip to content

Commit 59965c9

Browse files
committed
fix: list_models() issue, modelId param, docs: add aws bedrock integration docs
1 parent a5b3388 commit 59965c9

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def __init__(
123123
self.chat_completion_parser = (
124124
chat_completion_parser or get_first_message_content
125125
)
126+
self.inference_parameters = [
127+
"maxTokens",
128+
"temperature",
129+
"topP",
130+
"stopSequences",
131+
]
126132

127133
def init_sync_client(self):
128134
"""
@@ -233,6 +239,47 @@ def list_models(self):
233239
except Exception as e:
234240
print(f"Error listing models: {e}")
235241

242+
def _validate_and_process_model_id(self, api_kwargs: Dict):
243+
"""
244+
Validate and process the model ID in API kwargs.
245+
246+
:param api_kwargs: Dictionary of API keyword arguments
247+
:raises KeyError: If 'model' key is missing
248+
"""
249+
if "model" in api_kwargs:
250+
api_kwargs["modelId"] = api_kwargs.pop("model")
251+
else:
252+
raise KeyError("The required key 'model' is missing in model_kwargs.")
253+
254+
return api_kwargs
255+
256+
def _separate_parameters(self, api_kwargs: Dict) -> tuple:
257+
"""
258+
Separate inference configuration and additional model request fields.
259+
260+
:param api_kwargs: Dictionary of API keyword arguments
261+
:return: Tuple of (inference_config, additional_model_request_fields)
262+
"""
263+
inference_config = {}
264+
additional_model_request_fields = {}
265+
keys_to_remove = set()
266+
excluded_keys = {"modelId"}
267+
268+
# Categorize parameters
269+
for key, value in list(api_kwargs.items()):
270+
if key in self.inference_parameters:
271+
inference_config[key] = value
272+
keys_to_remove.add(key)
273+
elif key not in excluded_keys:
274+
additional_model_request_fields[key] = value
275+
keys_to_remove.add(key)
276+
277+
# Remove categorized keys from api_kwargs
278+
for key in keys_to_remove:
279+
api_kwargs.pop(key, None)
280+
281+
return api_kwargs, inference_config, additional_model_request_fields
282+
236283
def convert_inputs_to_api_kwargs(
237284
self,
238285
input: Optional[Any] = None,
@@ -245,9 +292,17 @@ def convert_inputs_to_api_kwargs(
245292
"""
246293
api_kwargs = model_kwargs.copy()
247294
if model_type == ModelType.LLM:
295+
# Validate and process model ID
296+
api_kwargs = self._validate_and_process_model_id(api_kwargs)
297+
298+
# Separate inference config and additional model request fields
299+
api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs)
300+
248301
api_kwargs["messages"] = [
249302
{"role": "user", "content": [{"text": input}]},
250303
]
304+
api_kwargs["inferenceConfig"] = inference_config
305+
api_kwargs["additionalModelRequestFields"] = additional_model_request_fields
251306
else:
252307
raise ValueError(f"Model type {model_type} not supported")
253308
return api_kwargs

0 commit comments

Comments
 (0)