11import argparse
22
3- from langchain_core .language_models import BaseLLM , BaseChatModel
3+ from langchain_core .language_models import BaseChatModel , BaseLLM
44from langchain_core .language_models .base import BaseLanguageModel
55
66# important note: if you import these after patching, the patch won't apply!
1414
def parse_args():
    """Parse command-line options for the LangChain model-comparison script.

    Returns:
        argparse.Namespace with three attributes:
        - provider: one of "ollama", "openai", "huggingface" (default "ollama")
        - model:    optional model name string (None when not given; each
                    provider factory supplies its own default)
        - prompt:   the input prompt (defaults to a simple factual question)
    """
    parser = argparse.ArgumentParser(description="LangChain model comparison")
    parser.add_argument(
        "--provider",
        choices=["ollama", "openai", "huggingface"],
        default="ollama",
        help="Choose model provider (default: ollama)",
    )
    parser.add_argument("--model", type=str, help="Specify model name")
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="Input prompt",
    )

    return parser.parse_args()
2432
2533
2634def chat_with_model (model : BaseLanguageModel , prompt : str ) -> str :
2735 try :
2836 response = model .invoke (prompt )
29- if hasattr (response , ' content' ):
37+ if hasattr (response , " content" ):
3038 return response .content
3139 else :
3240 return str (response )
@@ -35,24 +43,15 @@ def chat_with_model(model: BaseLanguageModel, prompt: str) -> str:
3543
3644
def create_huggingface_model(model: str = "google/flan-t5-small"):
    """Build a HuggingFace Inference Endpoint model.

    Args:
        model: HuggingFace Hub repo id to query.

    Returns:
        A HuggingFaceEndpoint LLM instance.
    """
    # temperature=0.7 matches the other provider factories so the
    # cross-provider comparison uses identical sampling settings.
    return HuggingFaceEndpoint(repo_id=model, temperature=0.7)
4247
4348
def create_openai_model(model: str = "gpt-3.5-turbo"):
    """Build an OpenAI chat model.

    Args:
        model: OpenAI model name.

    Returns:
        A ChatOpenAI instance.
    """
    # temperature=0.7 matches the other provider factories so the
    # cross-provider comparison uses identical sampling settings.
    return ChatOpenAI(model=model, temperature=0.7)
4951
5052
def create_ollama_model(model: str = "llama2"):
    """Build a local Ollama model.

    Args:
        model: Name of the locally served Ollama model.

    Returns:
        An OllamaLLM instance.
    """
    # temperature=0.7 matches the other provider factories so the
    # cross-provider comparison uses identical sampling settings.
    return OllamaLLM(model=model, temperature=0.7)
5655
5756
5857def patch_llm ():
0 commit comments