Skip to content

Commit be953d9

Browse files
committed
fix: list_models() issue, modelId param, docs: add aws bedrock integration docs
1 parent ead370d commit be953d9

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ def __init__(
118118
self.chat_completion_parser = (
119119
chat_completion_parser or get_first_message_content
120120
)
121+
self.inference_parameters = [
122+
"maxTokens",
123+
"temperature",
124+
"topP",
125+
"stopSequences",
126+
]
121127

122128
def init_sync_client(self):
123129
"""
@@ -192,6 +198,47 @@ def list_models(self):
192198
except Exception as e:
193199
print(f"Error listing models: {e}")
194200

201+
def _validate_and_process_model_id(self, api_kwargs: Dict):
202+
"""
203+
Validate and process the model ID in API kwargs.
204+
205+
:param api_kwargs: Dictionary of API keyword arguments
206+
:raises KeyError: If 'model' key is missing
207+
"""
208+
if "model" in api_kwargs:
209+
api_kwargs["modelId"] = api_kwargs.pop("model")
210+
else:
211+
raise KeyError("The required key 'model' is missing in model_kwargs.")
212+
213+
return api_kwargs
214+
215+
def _separate_parameters(self, api_kwargs: Dict) -> tuple:
216+
"""
217+
Separate inference configuration and additional model request fields.
218+
219+
:param api_kwargs: Dictionary of API keyword arguments
220+
:return: Tuple of (inference_config, additional_model_request_fields)
221+
"""
222+
inference_config = {}
223+
additional_model_request_fields = {}
224+
keys_to_remove = set()
225+
excluded_keys = {"modelId"}
226+
227+
# Categorize parameters
228+
for key, value in list(api_kwargs.items()):
229+
if key in self.inference_parameters:
230+
inference_config[key] = value
231+
keys_to_remove.add(key)
232+
elif key not in excluded_keys:
233+
additional_model_request_fields[key] = value
234+
keys_to_remove.add(key)
235+
236+
# Remove categorized keys from api_kwargs
237+
for key in keys_to_remove:
238+
api_kwargs.pop(key, None)
239+
240+
return api_kwargs, inference_config, additional_model_request_fields
241+
195242
def convert_inputs_to_api_kwargs(
196243
self,
197244
input: Optional[Any] = None,
@@ -204,9 +251,17 @@ def convert_inputs_to_api_kwargs(
204251
"""
205252
api_kwargs = model_kwargs.copy()
206253
if model_type == ModelType.LLM:
254+
# Validate and process model ID
255+
api_kwargs = self._validate_and_process_model_id(api_kwargs)
256+
257+
# Separate inference config and additional model request fields
258+
api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs)
259+
207260
api_kwargs["messages"] = [
208261
{"role": "user", "content": [{"text": input}]},
209262
]
263+
api_kwargs["inferenceConfig"] = inference_config
264+
api_kwargs["additionalModelRequestFields"] = additional_model_request_fields
210265
else:
211266
raise ValueError(f"Model type {model_type} not supported")
212267
return api_kwargs

0 commit comments

Comments
 (0)