44from langchain_openai import ChatOpenAI
55
66from pydantic import BaseModel , Field
7- from typing import Type , Optional , Dict , Any , List
7+ from typing import Type , Optional , Dict , Any , List , Literal
8+ import logging
89
910from src .core .ports .secondary .template_service import TemplateService
1011from src .core .agents .utils .state import AgentState
11- from src .core .domain .company_search import CompanySearchResponse
12+ from src .core .domain .company_search import CompanyInfo
1213
14+ logger = logging .getLogger ("core.agents.search_agents" )
1315
1416def get_search_tool (max_results = 3 , time_range = "month" ):
1517 """
@@ -30,7 +32,7 @@ def get_search_tool(max_results=3, time_range="month"):
3032 topic = "general" , # Use "general" for company information
3133 )
3234
33- async def create_company_search_agent (response_format : Type [BaseModel ],
35+ def create_company_search_agent (response_format : Type [BaseModel ],
3436 template_service : TemplateService ,
3537 model_name = "gpt-3.5-turbo" ):
3638 """
@@ -65,25 +67,26 @@ async def create_company_search_agent(response_format: Type[BaseModel],
6567
6668async def search_company_info (company_name : str ,
6769 template_service : TemplateService ,
68- model_name = "gpt-3.5-turbo" ) -> Optional [CompanySearchResponse ]:
70+ model_name = "gpt-3.5-turbo" ) -> Optional [CompanyInfo ]:
6971 """
7072 Search for up-to-date information about a company and return structured data.
7173 """
74+ logger .info (f"Collecting more details on { company_name } " )
7275 # TODO: Retry on OpenAI rate limit errors
7376 search_query = template_service .render_prompt (
7477 "prompts/company_search/search_query.j2" ,
75- ** {"company_name" : company_name , "search_result_format" : CompanySearchResponse .model_json_schema ()}
78+ ** {"company_name" : company_name , "search_result_format" : CompanyInfo .model_json_schema ()}
7679 )
7780
7881 try :
79- agent = await create_company_search_agent (CompanySearchResponse , template_service , model_name )
82+ agent = create_company_search_agent (CompanyInfo , template_service , model_name )
8083 response = await agent .ainvoke (
8184 {"messages" : [{"role" : "user" , "content" : search_query }]}
8285 )
8386
8487 if "structured_response" in response :
8588 result = response ["structured_response" ]
86- return CompanySearchResponse .model_validate (result )
89+ return CompanyInfo .model_validate (result )
8790 else :
8891 print ("No structured response found in agent output" )
8992 return None
@@ -92,23 +95,39 @@ async def search_company_info(company_name: str,
9295 return None
9396
9497
95- async def main ():
98+
99+
100+ if __name__ == "__main__" :
101+ import asyncio
96102 from src .core .domain .config import TemplateConfig
97103 from src .infrastructure .template .jinja_template_service import JinjaTemplateService
98104 template_config = TemplateConfig .development ()
99105 template_service = JinjaTemplateService (config = template_config )
100106
101- company_info = await search_company_info ("Apple Inc." , template_service , model_name = "gpt-4o" )
102-
103- if company_info :
104- print ("\n Company Information:" )
105- print (f"Name: { company_info .company_name } " )
106- print (f"Industry: { company_info .company_industry } " )
107- print (f"Size: { company_info .company_size } employees" )
108- print (f"Revenue: { company_info .company_revenue } " )
109- print (f"Location: { company_info .company_location } " )
110- print (f"Website: { company_info .company_website } " )
111- print (f"Founded: { company_info .founded_year or 'Unknown' } " )
112- print (f"Description: { company_info .company_description } " )
113- else :
114- print ("Could not retrieve company information." )
107+ async def main ():
108+ companies = ["Apple Inc." , "Microsoft" , "Google" , "A non-existent company" ,
109+ "NVIDIA" , "Tesla" , "Amazon" , "Meta" ]
110+ concurrency_limit = 5
111+
112+ tasks = [search_company_info (company_name = name , template_service = template_service ,
113+ model_name = "gpt-4.1-mini" ) for name in companies ]
114+ results = await asyncio .gather (* tasks )
115+
116+ company_info_map = {info .name : info for info in results if info }
117+
118+ for company_name in companies :
119+ company_info = company_info_map .get (company_name )
120+ if company_info :
121+ print ("\n Company Information:" )
122+ print (f" Name: { company_info .name } " )
123+ print (f" Industry: { company_info .industry } " )
124+ print (f" Size: { company_info .size } employees" )
125+ print (f" Revenue: { company_info .revenue } " )
126+ print (f" Location: { company_info .location } " )
127+ print (f" Website: { company_info .website } " )
128+ print (f" Founded: { company_info .founded_year or 'Unknown' } " )
129+ print (f" Description: { company_info .description } " )
130+ else :
131+ print (f"\n Could not retrieve company information for { company_name } ." )
132+
133+ asyncio .run (main ())
0 commit comments