Skip to content

Commit 85868c1

Browse files
committed
fix: aws credentials_profile_name
1 parent de97de6 commit 85868c1

File tree

5 files changed

+27
-40
lines changed

5 files changed

+27
-40
lines changed

apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,6 @@
1111

1212
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
1313

14-
@staticmethod
15-
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
16-
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
17-
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
18-
19-
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
20-
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
21-
content = re.sub(pattern, '', content, flags=re.DOTALL)
22-
23-
if not re.search(rf'\[{profile_name}\]', content):
24-
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
25-
26-
with open(credentials_path, 'w') as file:
27-
file.write(content)
28-
2914
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
3015
raise_exception=False):
3116
model_type_list = provider.get_model_type_list()
@@ -41,9 +26,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4126
return False
4227

4328
try:
44-
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
45-
model_credential['secret_access_key'])
46-
model_credential['credentials_profile_name'] = 'aws-profile'
4729
model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
4830
aa = model.embed_query('你好')
4931
print(aa)

apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import os
2-
import re
1+
32
from typing import Dict
43

54
from langchain_core.messages import HumanMessage
@@ -29,20 +28,7 @@ class BedrockLLMModelParams(BaseForm):
2928

3029
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
3130

32-
@staticmethod
33-
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
34-
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
35-
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
36-
37-
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
38-
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
39-
content = re.sub(pattern, '', content, flags=re.DOTALL)
40-
41-
if not re.search(rf'\[{profile_name}\]', content):
42-
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
4331

44-
with open(credentials_path, 'w') as file:
45-
file.write(content)
4632

4733
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
4834
raise_exception=False):
@@ -59,9 +45,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
5945
return False
6046

6147
try:
62-
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
63-
model_credential['secret_access_key'])
64-
model_credential['credentials_profile_name'] = 'aws-profile'
6548
model = provider.get_model(model_type, model_name, model_credential, **model_params)
6649
model.invoke([HumanMessage(content='你好')])
6750
except AppApiException:

apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from setting.models_provider.base_model_provider import MaxKBBaseModel
44
from typing import Dict, List
55

6+
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
7+
68

79
class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
810
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
@@ -13,10 +15,12 @@ def __init__(self, model_id: str, region_name: str, credentials_profile_name: st
1315
@classmethod
1416
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
1517
**model_kwargs) -> 'BedrockModel':
18+
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
19+
model_credential['secret_access_key'])
1620
return cls(
1721
model_id=model_name,
1822
region_name=model_credential['region_name'],
19-
credentials_profile_name=model_credential['credentials_profile_name'],
23+
credentials_profile_name=model_credential['access_key_id'],
2024
)
2125

2226
def embed_documents(self, texts: List[str]) -> List[List[float]]:

apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Dict
2-
2+
import os
3+
import re
34
from botocore.config import Config
45
from langchain_community.chat_models import BedrockChat
56
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -57,12 +58,29 @@ def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[s
5758
connect_timeout=60,
5859
read_timeout=60
5960
)
61+
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
62+
model_credential['secret_access_key'])
6063

6164
return cls(
6265
model_id=model_name,
6366
region_name=model_credential['region_name'],
64-
credentials_profile_name=model_credential['credentials_profile_name'],
67+
credentials_profile_name=model_credential['access_key_id'],
6568
streaming=model_kwargs.pop('streaming', True),
6669
model_kwargs=optional_params,
6770
config=config
6871
)
72+
73+
74+
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
75+
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
76+
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
77+
78+
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
79+
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
80+
content = re.sub(pattern, '', content, flags=re.DOTALL)
81+
82+
if not re.search(rf'\[{profile_name}\]', content):
83+
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
84+
85+
with open(credentials_path, 'w') as file:
86+
file.write(content)

apps/setting/serializers/provider_serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def is_valid(self, model=None, raise_exception=False):
160160
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
161161
if credential is not None:
162162
for k in source_encryption_model_credential.keys():
163-
if credential[k] == source_encryption_model_credential[k]:
163+
if k in credential and credential[k] == source_encryption_model_credential[k]:
164164
credential[k] = source_model_credential[k]
165165
return credential, model_credential, provider_handler
166166

0 commit comments

Comments
 (0)