
Commit 33da607

feat: implement AWS Bedrock Vision-Language and Reranker models with credential validation
1 parent 12bc39f commit 33da607

6 files changed: +367 -3 lines changed


apps/common/handle/impl/response/openai_to_response.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 class OpenaiToResponse(BaseToResponse):
-    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, prompt_tokens, completion_tokens,
                           other_params: dict = None,
                           _status=status.HTTP_200_OK):
         if other_params is None:
@@ -37,8 +37,8 @@ def to_block_response(self, chat_id, chat_record_id, content, is_end, completion
         return JsonResponse(data=data, status=_status)

     def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
-                                 completion_tokens,
-                                 prompt_tokens, other_params: dict = None):
+                                 prompt_tokens,
+                                 completion_tokens, other_params: dict = None):
         if other_params is None:
             other_params = {}
         chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
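The two hunks above reorder prompt_tokens and completion_tokens into the conventional (prompt, completion) order; a caller that passed the counts positionally in the old order would have recorded them reversed. A minimal caller sketch (hypothetical helper, not part of this commit) showing how passing the counts by keyword keeps the accounting independent of the parameter order:

    # Hypothetical helper, not part of this commit: build a block response from an
    # OpenAI-style usage object, passing token counts by keyword so a reordering of
    # the positional parameters (as in this commit) can never swap them silently.
    def build_block_response(handler, chat_id, chat_record_id, answer_text, usage):
        return handler.to_block_response(
            chat_id=chat_id,
            chat_record_id=chat_record_id,
            content=answer_text,
            is_end=True,
            prompt_tokens=usage.prompt_tokens,          # tokens consumed by the request
            completion_tokens=usage.completion_tokens,  # tokens in the generated answer
        )
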

apps/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py

Lines changed: 50 additions & 0 deletions
@@ -7,12 +7,17 @@
     IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
 )
 from models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
+from models_provider.impl.aws_bedrock_model_provider.credential.image import BedrockVLModelCredential
 from models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
+from models_provider.impl.aws_bedrock_model_provider.credential.reranker import BedrockRerankerCredential
 from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
+from models_provider.impl.aws_bedrock_model_provider.model.image import BedrockVLModel
 from models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
 from maxkb.conf import PROJECT_DIR
 from django.utils.translation import gettext as _

+from models_provider.impl.aws_bedrock_model_provider.model.reranker import BedrockRerankerModel
+

 def _create_model_info(model_name, description, model_type, credential_class, model_class):
     return ModelInfo(
@@ -127,11 +132,56 @@ def _initialize_model_info():
         ),
     ]

+    reranker_model_info_list = [
+        _create_model_info(
+            'amazon.rerank-v1:0',
+            '',
+            ModelTypeConst.RERANKER,
+            BedrockRerankerCredential,
+            BedrockRerankerModel
+        ),
+        _create_model_info(
+            'cohere.rerank-v3-5:0',
+            '',
+            ModelTypeConst.RERANKER,
+            BedrockRerankerCredential,
+            BedrockRerankerModel
+        )
+    ]
+    vl_model_info_list = [
+
+        _create_model_info(
+            'global.anthropic.claude-sonnet-4-5-20250929-v1:0',
+            '',
+            ModelTypeConst.IMAGE,
+            BedrockVLModelCredential,
+            BedrockVLModel
+        ),
+        _create_model_info(
+            'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
+            '',
+            ModelTypeConst.IMAGE,
+            BedrockVLModelCredential,
+            BedrockVLModel
+        ),
+        _create_model_info(
+            'global.anthropic.claude-haiku-4-5-20251001-v1:0',
+            '',
+            ModelTypeConst.IMAGE,
+            BedrockVLModelCredential,
+            BedrockVLModel
+        ),
+    ]
+
     model_info_manage = ModelInfoManage.builder() \
         .append_model_info_list(model_info_list) \
         .append_default_model_info(model_info_list[0]) \
         .append_model_info_list(embedded_model_info_list) \
         .append_default_model_info(embedded_model_info_list[0]) \
+        .append_model_info_list(vl_model_info_list) \
+        .append_default_model_info(vl_model_info_list[0]) \
+        .append_model_info_list(reranker_model_info_list) \
+        .append_default_model_info(reranker_model_info_list[0]) \
         .build()

     return model_info_manage
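Each entry above is produced by the module's _create_model_info helper and registered through the ModelInfoManage builder chain, with the first entry of each list set as the default. As an illustration only (the model ID below is hypothetical and not part of this commit), a further vision-language entry would be registered inside _initialize_model_info() the same way:

    # Illustrative sketch only, not part of this commit: registering one more
    # Bedrock vision-language entry with the same helper and builder chain.
    extra_vl_info = _create_model_info(
        'anthropic.claude-3-5-sonnet-20240620-v1:0',  # hypothetical model ID
        '',                                           # description left empty, as in this commit
        ModelTypeConst.IMAGE,
        BedrockVLModelCredential,
        BedrockVLModel
    )
    vl_model_info_list.append(extra_vl_info)
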
apps/models_provider/impl/aws_bedrock_model_provider/credential/image.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import ValidCode, BaseModelCredential
from common.utils.logger import maxkb_logger


class BedrockImageModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class BedrockVLModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('{model_type} Model type is not supported').format(model_type=model_type))
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('The following fields are required: {keys}').format(
                                          keys=", ".join(required_keys)))
            return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except AppApiException:
            raise
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
    base_url = forms.TextInputField('Proxy URL', required=False)

    def get_model_params_setting_form(self, model_name):
        return BedrockImageModelParams()
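A hypothetical driver sketch, not part of this commit, of how this credential form would be exercised before a model record is saved. The provider argument is assumed to be the AWS Bedrock provider instance, the 'IMAGE' type key is an assumption about ModelTypeConst.IMAGE's value, and the credential values are placeholders mirroring the form fields declared above:

    # Hypothetical driver sketch, not part of this commit.
    def validate_vl_credential(provider, model_name='us.anthropic.claude-sonnet-4-5-20250929-v1:0'):
        credential = {
            'region_name': 'us-east-1',
            'access_key_id': 'AKIA...',        # placeholder, not a real key
            'secret_access_key': '<secret>',   # placeholder
        }
        form = BedrockVLModelCredential()
        ok = form.is_valid('IMAGE',            # assumed value of ModelTypeConst.IMAGE
                           model_name, credential,
                           {'temperature': 0.7, 'max_tokens': 1024},
                           provider, raise_exception=False)
        # secret_access_key is masked by encryption_dict before the record is stored
        return form.encryption_dict(credential) if ok else None
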
apps/models_provider/impl/aws_bedrock_model_provider/credential/reranker.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import ValidCode, BaseModelCredential


class BedrockRerankerModelParams(BaseForm):
    top_n = forms.SliderField(TooltipLabel(_('Top N'),
                                           _('Number of top documents to return after reranking')),
                              required=True, default_value=3,
                              _min=1,
                              _max=20,
                              _step=1,
                              precision=0)


class BedrockRerankerCredential(BaseForm, BaseModelCredential):
    access_key_id = forms.PasswordInputField(_('Access Key ID'), required=True)
    secret_access_key = forms.PasswordInputField(_('Secret Access Key'), required=True)
    region_name = forms.TextInputField(_('Region Name'), required=True, default_value='us-east-1')
    base_url = forms.TextInputField(_('Base URL'), required=False)

    def is_valid(self, model_type: str, model_name: str, model_credential: Dict[str, object], model_params,
                 provider,
                 raise_exception: bool = False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, _('Model type is not supported'))

        for key in ['access_key_id', 'secret_access_key', 'region_name']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('%(key)s is required') % {'key': key})
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            # Rerank a few short test documents to verify credentials and model access
            test_docs = [
                Document(page_content=str(_('Hello'))),
                Document(page_content=str(_('World'))),
                Document(page_content=str(_('Test')))
            ]
            model.compress_documents(test_docs, str(_('Hello')))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: %(error)s') % {'error': str(e)})
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'access_key_id': super().encryption(model.get('access_key_id', '')),
                'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return BedrockRerankerModelParams()
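Once the credential validates, the reranker is used through the same compress_documents call shown in the validation path. A minimal usage sketch (hypothetical helper, not part of this commit; `reranker` is assumed to be the BedrockRerankerModel instance returned by provider.get_model):

    # Hypothetical sketch, not part of this commit: reranking retrieved passages
    # with a validated Bedrock reranker.
    from langchain_core.documents import Document

    def rerank_passages(reranker, query: str, passages: list) -> list:
        docs = [Document(page_content=text) for text in passages]
        # compress_documents returns the passages ordered by relevance to the query,
        # truncated to the configured top_n
        ranked = reranker.compress_documents(docs, query)
        return [d.page_content for d in ranked]
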
apps/models_provider/impl/aws_bedrock_model_provider/model/image.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# coding=utf-8
"""
@project: MaxKB
@file: image.py
@desc: AWS Bedrock Vision-Language Model Implementation
"""
from typing import Dict, List

from botocore.config import Config
from langchain_aws import ChatBedrock
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials


class BedrockVLModel(MaxKBBaseModel, ChatBedrock):
    """
    AWS Bedrock Vision-Language Model
    Supports Claude 3 models with vision capabilities (Haiku, Sonnet, Opus)
    """

    @staticmethod
    def is_cache_model():
        return False

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, config: Config = None, **kwargs):
        super().__init__(
            model_id=model_id,
            region_name=region_name,
            credentials_profile_name=credentials_profile_name,
            streaming=streaming,
            config=config,
            **kwargs
        )

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockVLModel':
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        config = {}
        # Check if a proxy URL is provided
        if 'base_url' in model_credential and model_credential['base_url']:
            proxy_url = model_credential['base_url']
            config = Config(
                proxies={
                    'http': proxy_url,
                    'https': proxy_url
                },
                connect_timeout=60,
                read_timeout=60
            )
        _update_aws_credentials(
            model_credential['access_key_id'],
            model_credential['access_key_id'],
            model_credential['secret_access_key']
        )

        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['access_key_id'],
            streaming=model_kwargs.pop('streaming', True),
            model_kwargs=optional_params,
            config=config
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """
        Get the number of tokens from messages
        Falls back to local tokenizer if the model's tokenizer fails
        """
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        """
        Get the number of tokens from text
        Falls back to local tokenizer if the model's tokenizer fails
        """
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
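A hypothetical end-to-end sketch of using the class above, not part of this commit. Credential values are placeholders, the 'IMAGE' type key is an assumption, and the image content block follows LangChain's multimodal HumanMessage convention that ChatBedrock accepts for Anthropic models:

    # Hypothetical usage sketch, not part of this commit.
    import base64
    from langchain_core.messages import HumanMessage

    def describe_image(image_path: str) -> str:
        credential = {
            'region_name': 'us-east-1',
            'access_key_id': 'AKIA...',        # placeholder, not a real key
            'secret_access_key': '<secret>',   # placeholder
            'base_url': '',                    # leave empty unless a proxy is needed
        }
        model = BedrockVLModel.new_instance(
            'IMAGE',                           # assumed model_type key
            'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
            credential,
            streaming=False,
            temperature=0.7,
            max_tokens=1024,
        )
        with open(image_path, 'rb') as f:
            image_b64 = base64.b64encode(f.read()).decode()
        message = HumanMessage(content=[
            {'type': 'text', 'text': 'Describe this image.'},
            {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{image_b64}'}},
        ])
        return model.invoke([message]).content
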
