Skip to content

Commit 161f562

Browse files
committed
升级langchain等依赖同时调整llm_config相关配置
1 parent 7723b6c commit 161f562

31 files changed

+187
-127
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ setup_test.py
1414
build
1515
*egg-info
1616
dist
17-
.ipynb_checkpoints
17+
.ipynb_checkpoints
18+
zdatafront*

muagent/chat/search_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from langchain.callbacks import AsyncIteratorCallbackHandler
66
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
77
from langchain.prompts.chat import ChatPromptTemplate
8-
from langchain.docstore.document import Document
8+
from langchain_community.docstore.document import Document
99

1010
# from configs.model_config import (
1111
# PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,

muagent/codechat/code_search/code_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from muagent.codechat.code_search.cypher_generator import CypherGenerator
1717
from muagent.codechat.code_search.tagger import Tagger
18-
from muagent.embeddings.get_embedding import get_embedding
18+
from muagent.llm_models.get_embedding import get_embedding
1919
from muagent.llm_models.llm_config import LLMConfig, EmbedConfig
2020

2121

muagent/codechat/code_search/cypher_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
@time: 2023/11/24 上午10:17
66
@desc:
77
'''
8-
from langchain import PromptTemplate
8+
from langchain.prompts import PromptTemplate
99
from loguru import logger
1010

1111
from muagent.llm_models.openai_model import getChatModelFromConfig
@@ -14,7 +14,8 @@
1414
from langchain.schema import (
1515
HumanMessage,
1616
)
17-
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT, CYPHER_GENERATION_TEMPLATE
17+
# from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT, CYPHER_GENERATION_TEMPLATE
18+
from langchain_community.chains.graph_qa.prompts import CYPHER_GENERATION_TEMPLATE
1819

1920
schema = '''
2021
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]

muagent/codechat/codebase_handler/code_importer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from muagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
1414
from muagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
15-
from muagent.embeddings.get_embedding import get_embedding
15+
from muagent.llm_models.get_embedding import get_embedding
1616
from muagent.llm_models.llm_config import EmbedConfig
1717

1818

muagent/connector/memory_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from loguru import logger
77
import numpy as np
88

9-
from langchain.docstore.document import Document
9+
from langchain_community.docstore.document import Document
1010

1111

1212
from .schema import Memory, Message
1313
from muagent.service.service_factory import KBServiceFactory
1414
from muagent.llm_models import getChatModelFromConfig
1515
from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
16-
from muagent.embeddings.utils import load_embeddings_from_path
16+
from muagent.retrieval.utils import load_embeddings_from_path
1717
from muagent.utils.common_utils import *
1818
from muagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
1919
from muagent.orm import table_init
@@ -489,7 +489,7 @@ def check_chat_index(self, chat_index: str):
489489

490490

491491
from muagent.utils.tbase_util import TbaseHandler
492-
from muagent.embeddings.get_embedding import get_embedding
492+
from muagent.llm_models.get_embedding import get_embedding
493493
from redis.commands.search.field import (
494494
TextField,
495495
NumericField,

muagent/embeddings/get_embedding.py renamed to muagent/llm_models/get_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from loguru import logger
99

1010
# from configs.model_config import EMBEDDING_MODEL
11-
from muagent.embeddings.openai_embedding import OpenAIEmbedding
12-
from muagent.embeddings.huggingface_embedding import HFEmbedding
11+
from muagent.llm_models.openai_embedding import OpenAIEmbedding
12+
from muagent.llm_models.huggingface_embedding import HFEmbedding
1313
from muagent.llm_models.llm_config import EmbedConfig
1414

1515
def get_embedding(
@@ -35,7 +35,7 @@ def get_embedding(
3535
oae = OpenAIEmbedding()
3636
emb_res = oae.get_emb(text_list)
3737
elif engine == 'model':
38-
hfe = HFEmbedding(model_path, embedding_device)
38+
hfe = HFEmbedding(model_path, embed_config.model_device)
3939
emb_res = hfe.get_emb(text_list)
4040

4141
return emb_res

muagent/embeddings/huggingface_embedding.py renamed to muagent/llm_models/huggingface_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from loguru import logger
99
# from configs.model_config import EMBEDDING_DEVICE
1010
# from configs.model_config import embedding_model_dict
11-
from muagent.embeddings.utils import load_embeddings, load_embeddings_from_path
11+
from muagent.retrieval.utils import load_embeddings, load_embeddings_from_path
1212

1313

1414
class HFEmbedding:

muagent/llm_models/llm_config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class LLMConfig:
1111
def __init__(
1212
self,
1313
model_name: str = "gpt-3.5-turbo",
14+
model_engine: str = "openai",
1415
temperature: float = 0.25,
1516
stop: Union[List[str], str] = None,
1617
api_key: str = "",
@@ -19,12 +20,15 @@ def __init__(
1920
llm: LLM = None,
2021
**kwargs
2122
):
22-
23+
# only support http connection with others
24+
# llm_model init config
2325
self.model_name: str = model_name
26+
self.model_engine: str = model_engine
2427
self.temperature: float = temperature
2528
self.stop: Union[List[str], str] = stop
2629
self.api_key: str = api_key
2730
self.api_base_url: str = api_base_url
31+
# custom llm
2832
self.llm: LLM = llm
2933
#
3034
self.check_config()
@@ -55,7 +59,7 @@ def __init__(
5559
self.model_device: str = model_device
5660
self.api_key: str = api_key
5761
self.api_base_url: str = api_base_url
58-
#
62+
# custom embeddings
5963
self.langchain_embeddings = langchain_embeddings
6064
#
6165
self.check_config()

0 commit comments

Comments
 (0)