Skip to content

Commit ee1d862

Browse files
committed
Merge PR#75
Manually merge coleam00#75 so that it plays nice with coleam00#64
1 parent 65e83a7 commit ee1d862

File tree

1 file changed

+70
-49
lines changed

1 file changed

+70
-49
lines changed

src/crawl4ai_mcp.py

Lines changed: 70 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -148,49 +148,12 @@ async def crawl4ai_lifespan(server: FastMCP) -> AsyncIterator[Crawl4AIContext]:
148148
# Initialize Supabase client
149149
supabase_client = get_supabase_client()
150150

151-
# Initialize cross-encoder model for reranking if enabled
151+
# Initialize components as None - they will be loaded lazily when needed
152152
reranking_model = None
153-
if os.getenv("USE_RERANKING", "false") == "true":
154-
try:
155-
reranking_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
156-
except Exception as e:
157-
print(f"Failed to load reranking model: {e}")
158-
reranking_model = None
159-
160-
# Initialize Neo4j components if configured and enabled
161153
knowledge_validator = None
162154
repo_extractor = None
163155

164-
# Check if knowledge graph functionality is enabled
165-
knowledge_graph_enabled = os.getenv("USE_KNOWLEDGE_GRAPH", "false") == "true"
166-
167-
if knowledge_graph_enabled:
168-
neo4j_uri = os.getenv("NEO4J_URI")
169-
neo4j_user = os.getenv("NEO4J_USER")
170-
neo4j_password = os.getenv("NEO4J_PASSWORD")
171-
172-
if neo4j_uri and neo4j_user and neo4j_password:
173-
try:
174-
print("Initializing knowledge graph components...")
175-
176-
# Initialize knowledge graph validator
177-
knowledge_validator = KnowledgeGraphValidator(neo4j_uri, neo4j_user, neo4j_password)
178-
await knowledge_validator.initialize()
179-
print("✓ Knowledge graph validator initialized")
180-
181-
# Initialize repository extractor
182-
repo_extractor = DirectNeo4jExtractor(neo4j_uri, neo4j_user, neo4j_password)
183-
await repo_extractor.initialize()
184-
print("✓ Repository extractor initialized")
185-
186-
except Exception as e:
187-
print(f"Failed to initialize Neo4j components: {format_neo4j_error(e)}")
188-
knowledge_validator = None
189-
repo_extractor = None
190-
else:
191-
print("Neo4j credentials not configured - knowledge graph tools will be unavailable")
192-
else:
193-
print("Knowledge graph functionality disabled - set USE_KNOWLEDGE_GRAPH=true to enable")
156+
print("✓ Basic components initialized (lazy loading enabled for heavy components)")
194157

195158
try:
196159
yield Crawl4AIContext(
@@ -216,6 +179,58 @@ async def crawl4ai_lifespan(server: FastMCP) -> AsyncIterator[Crawl4AIContext]:
216179
except Exception as e:
217180
print(f"Error closing repository extractor: {e}")
218181

182+
# Add lazy loading functions
183+
async def get_reranking_model(ctx: Crawl4AIContext) -> Optional[CrossEncoder]:
184+
"""Lazy load the reranking model only when needed."""
185+
if ctx.reranking_model is None and os.getenv("USE_RERANKING", "false") == "true":
186+
try:
187+
print("Loading reranking model...")
188+
ctx.reranking_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
189+
print("✓ Reranking model loaded")
190+
except Exception as e:
191+
print(f"Failed to load reranking model: {e}")
192+
ctx.reranking_model = None
193+
return ctx.reranking_model
194+
195+
async def get_knowledge_validator(ctx: Crawl4AIContext) -> Optional[Any]:
196+
"""Lazy load the knowledge graph validator only when needed."""
197+
if ctx.knowledge_validator is None and os.getenv("USE_KNOWLEDGE_GRAPH", "false") == "true":
198+
neo4j_uri = os.getenv("NEO4J_URI")
199+
neo4j_user = os.getenv("NEO4J_USER")
200+
neo4j_password = os.getenv("NEO4J_PASSWORD")
201+
202+
if neo4j_uri and neo4j_user and neo4j_password:
203+
try:
204+
print("Loading knowledge graph validator...")
205+
ctx.knowledge_validator = KnowledgeGraphValidator(neo4j_uri, neo4j_user, neo4j_password)
206+
await ctx.knowledge_validator.initialize()
207+
print("✓ Knowledge graph validator loaded")
208+
except Exception as e:
209+
print(f"Failed to initialize knowledge validator: {format_neo4j_error(e)}")
210+
ctx.knowledge_validator = None
211+
return ctx.knowledge_validator
212+
213+
async def get_repo_extractor(ctx: Optional[Crawl4AIContext]) -> Optional[Any]:
214+
"""Lazy load the repository extractor only when needed."""
215+
if ctx is None:
216+
return None
217+
218+
if ctx.repo_extractor is None and os.getenv("USE_KNOWLEDGE_GRAPH", "false") == "true":
219+
neo4j_uri = os.getenv("NEO4J_URI")
220+
neo4j_user = os.getenv("NEO4J_USER")
221+
neo4j_password = os.getenv("NEO4J_PASSWORD")
222+
223+
if neo4j_uri and neo4j_user and neo4j_password:
224+
try:
225+
print("Loading repository extractor...")
226+
ctx.repo_extractor = DirectNeo4jExtractor(neo4j_uri, neo4j_user, neo4j_password)
227+
await ctx.repo_extractor.initialize()
228+
print("✓ Repository extractor loaded")
229+
except Exception as e:
230+
print(f"Failed to initialize repository extractor: {format_neo4j_error(e)}")
231+
ctx.repo_extractor = None
232+
return ctx.repo_extractor
233+
219234
# Initialize FastMCP server
220235
mcp = FastMCP(
221236
"mcp-crawl4ai-rag",
@@ -880,9 +895,12 @@ async def perform_rag_query(ctx: Context, query: str, source: str = None, match_
880895

881896
# Apply reranking if enabled
882897
use_reranking = os.getenv("USE_RERANKING", "false") == "true"
883-
if use_reranking and ctx.request_context.lifespan_context.reranking_model:
884-
results = rerank_results(ctx.request_context.lifespan_context.reranking_model, query, results, content_key="content")
885-
898+
reranking_model = None
899+
if use_reranking:
900+
reranking_model = await get_reranking_model(ctx.request_context.lifespan_context)
901+
if reranking_model:
902+
results = rerank_results(reranking_model, query, results, content_key="content")
903+
886904
# Format the results
887905
formatted_results = []
888906
for result in results:
@@ -902,7 +920,7 @@ async def perform_rag_query(ctx: Context, query: str, source: str = None, match_
902920
"query": query,
903921
"source_filter": source,
904922
"search_mode": "hybrid" if use_hybrid_search else "vector",
905-
"reranking_applied": use_reranking and ctx.request_context.lifespan_context.reranking_model is not None,
923+
"reranking_applied": use_reranking and reranking_model is not None,
906924
"results": formatted_results,
907925
"count": len(formatted_results)
908926
}, indent=2)
@@ -1035,8 +1053,11 @@ async def search_code_examples(ctx: Context, query: str, source_id: str = None,
10351053

10361054
# Apply reranking if enabled
10371055
use_reranking = os.getenv("USE_RERANKING", "false") == "true"
1038-
if use_reranking and ctx.request_context.lifespan_context.reranking_model:
1039-
results = rerank_results(ctx.request_context.lifespan_context.reranking_model, query, results, content_key="content")
1056+
reranking_model = None
1057+
if use_reranking:
1058+
reranking_model = await get_reranking_model(ctx.request_context.lifespan_context)
1059+
if reranking_model:
1060+
results = rerank_results(reranking_model, query, results, content_key="content")
10401061

10411062
# Format the results
10421063
formatted_results = []
@@ -1059,7 +1080,7 @@ async def search_code_examples(ctx: Context, query: str, source_id: str = None,
10591080
"query": query,
10601081
"source_filter": source_id,
10611082
"search_mode": "hybrid" if use_hybrid_search else "vector",
1062-
"reranking_applied": use_reranking and ctx.request_context.lifespan_context.reranking_model is not None,
1083+
"reranking_applied": use_reranking and reranking_model is not None,
10631084
"results": formatted_results,
10641085
"count": len(formatted_results)
10651086
}, indent=2)
@@ -1103,7 +1124,7 @@ async def check_ai_script_hallucinations(ctx: Context, script_path: str) -> str:
11031124
}, indent=2)
11041125

11051126
# Get the knowledge validator from context
1106-
knowledge_validator = ctx.request_context.lifespan_context.knowledge_validator
1127+
knowledge_validator = await get_knowledge_validator(ctx.request_context.lifespan_context)
11071128

11081129
if not knowledge_validator:
11091130
return json.dumps({
@@ -1242,7 +1263,7 @@ async def query_knowledge_graph(ctx: Context, command: str) -> str:
12421263
}, indent=2)
12431264

12441265
# Get Neo4j driver from context
1245-
repo_extractor = ctx.request_context.lifespan_context.repo_extractor
1266+
repo_extractor = await get_repo_extractor(ctx.request_context.lifespan_context)
12461267
if not repo_extractor or not repo_extractor.driver:
12471268
return json.dumps({
12481269
"success": False,
@@ -1681,7 +1702,7 @@ async def _analyze_and_store_repository(ctx: Context, repo_identifier: str, is_l
16811702
}
16821703

16831704
# Get the repository extractor from context
1684-
repo_extractor = getattr(getattr(ctx, 'request_context', None), 'lifespan_context', None) and ctx.request_context.lifespan_context.repo_extractor
1705+
repo_extractor = await get_repo_extractor(getattr(getattr(ctx, 'request_context', None), 'lifespan_context', None)) if getattr(getattr(ctx, 'request_context', None), 'lifespan_context', None) else None
16851706

16861707
if not repo_extractor:
16871708
return {

0 commit comments

Comments
 (0)