
Commit bf729a0

Author: Adityavardhan Agrawal
Add google drive connector (#150)
1 parent 2b909e2 commit bf729a0

23 files changed: +2,355 additions, -188 deletions

core/api.py

Lines changed: 91 additions & 134 deletions
@@ -3,6 +3,7 @@
 import json
 import logging
 import uuid
+from contextlib import asynccontextmanager
 from datetime import UTC, datetime, timedelta
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -11,15 +12,17 @@
 import jwt
 import tomli
 from fastapi import Depends, FastAPI, File, Form, Header, HTTPException, UploadFile
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.cors import CORSMiddleware  # Import CORSMiddleware
 from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
+from starlette.middleware.sessions import SessionMiddleware

 from core.agent import MorphikAgent
 from core.auth_utils import verify_token
 from core.cache.llama_cache_factory import LlamaCacheFactory
 from core.completion.litellm_completion import LiteLLMCompletionModel
 from core.config import get_settings
 from core.database.postgres_database import PostgresDatabase
+from core.dependencies import get_redis_pool
 from core.embedding.colpali_api_embedding_model import ColpaliApiEmbeddingModel
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.embedding.litellm_embedding import LiteLLMEmbeddingModel
@@ -51,31 +54,82 @@
 from core.vector_store.pgvector_store import PGVectorStore

 # Initialize FastAPI app
-app = FastAPI(title="Morphik API")
 logger = logging.getLogger(__name__)

+# Global variable for redis_pool, primarily for shutdown if app.state access fails.
+_global_redis_pool: Optional[arq.ArqRedis] = None

-# Add health check endpoints
-@app.get("/health")
-async def health_check():
-    """Basic health check endpoint."""
-    return {"status": "healthy"}
-
-
-@app.get("/health/ready")
-async def readiness_check():
-    """Readiness check that verifies the application is initialized."""
-    return {
-        "status": "ready",
-        "components": {
-            "database": settings.DATABASE_PROVIDER,
-            "vector_store": settings.VECTOR_STORE_PROVIDER,
-            "embedding": settings.EMBEDDING_PROVIDER,
-            "completion": settings.COMPLETION_PROVIDER,
-            "storage": settings.STORAGE_PROVIDER,
-        },
-    }

+@asynccontextmanager
+async def lifespan(app_instance: FastAPI):
+    # --- BEGIN MOVED STARTUP LOGIC ---
+    logger.info("Lifespan: Initializing Database...")
+    try:
+        success = await database.initialize()
+        if success:
+            logger.info("Lifespan: Database initialization successful")
+        else:
+            logger.error("Lifespan: Database initialization failed")
+    except Exception as e:
+        logger.error(f"Lifespan: CRITICAL - Failed to initialize Database: {e}", exc_info=True)
+        raise  # Or handle more gracefully if appropriate
+
+    logger.info("Lifespan: Initializing Vector Store...")
+    try:
+        global vector_store
+        if hasattr(vector_store, "initialize"):
+            await vector_store.initialize()
+        logger.info("Lifespan: Vector Store initialization successful (or not applicable).")
+    except Exception as e:
+        logger.error(f"Lifespan: CRITICAL - Failed to initialize Vector Store: {e}", exc_info=True)
+        # Decide if this is fatal
+        # raise
+
+    logger.info("Lifespan: Attempting to initialize Redis connection pool...")
+    global _global_redis_pool  # Ensure we're using the global for assignment
+    try:
+        redis_settings_obj = arq.connections.RedisSettings(
+            host=settings.REDIS_HOST,
+            port=settings.REDIS_PORT,
+        )
+        logger.info(f"Lifespan: Redis settings for pool: host={settings.REDIS_HOST}, port={settings.REDIS_PORT}")
+        current_redis_pool = await arq.create_pool(redis_settings_obj)
+        if current_redis_pool:
+            app_instance.state.redis_pool = current_redis_pool  # Use app_instance from lifespan
+            _global_redis_pool = current_redis_pool
+            logger.info("Lifespan: Successfully initialized Redis connection pool and stored on app.state.")
+        else:
+            logger.error("Lifespan: arq.create_pool returned None or a falsey value for Redis pool.")
+            raise RuntimeError("Lifespan: Failed to create Redis pool - arq.create_pool returned non-truthy value.")
+    except Exception as e:
+        logger.error(f"Lifespan: CRITICAL - Failed to initialize Redis connection pool: {e}", exc_info=True)
+        raise RuntimeError(f"Lifespan: CRITICAL - Failed to initialize Redis connection pool: {e}") from e
+    # --- END MOVED STARTUP LOGIC ---
+
+    logger.info("Lifespan: Core startup logic executed.")
+    yield
+    # Shutdown logic
+    logger.info("Lifespan: Shutdown initiated.")
+    # Use app_instance.state to get the pool for shutdown too
+    pool_to_close = getattr(app_instance.state, "redis_pool", _global_redis_pool)
+    if pool_to_close:
+        logger.info("Closing Redis connection pool from lifespan...")
+        pool_to_close.close()
+        await pool_to_close.wait_closed()
+        logger.info("Redis connection pool closed from lifespan.")
+    logger.info("Lifespan: Shutdown complete.")
+
+
+app = FastAPI(lifespan=lifespan)
+
+# Add CORSMiddleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)

 # Initialize telemetry
 telemetry = TelemetryService()
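The hunk above replaces the deprecated @app.on_event("startup"/"shutdown") handlers with FastAPI's lifespan context manager: everything before the yield runs once at startup, everything after it runs once at shutdown, and shared resources live on app.state. A minimal standalone sketch of the pattern (DemoPool is a hypothetical stand-in, not part of this commit):

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI


class DemoPool:
    """Stands in for a real shared resource such as an arq pool or DB engine."""

    async def close(self) -> None:
        await asyncio.sleep(0)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Everything before `yield` runs once at startup.
    app.state.pool = DemoPool()
    yield
    # Everything after `yield` runs once at shutdown.
    await app.state.pool.close()


app = FastAPI(lifespan=lifespan)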
@@ -90,112 +144,17 @@ async def readiness_check():
     tracer_provider=None,  # Use the global tracer provider
 )

-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
 # Initialize service
 settings = get_settings()

+# Add SessionMiddleware
+app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY)
+
 # Initialize database
 if not settings.POSTGRES_URI:
     raise ValueError("PostgreSQL URI is required for PostgreSQL database")
 database = PostgresDatabase(uri=settings.POSTGRES_URI)

-# Redis settings already imported at top of file
-
-
-@app.on_event("startup")
-async def initialize_database():
-    """Initialize database tables and indexes on application startup."""
-    logger.info("Initializing database...")
-    success = await database.initialize()
-    if success:
-        logger.info("Database initialization successful")
-    else:
-        logger.error("Database initialization failed")
-        # We don't raise an exception here to allow the app to continue starting
-        # even if there are initialization errors
-
-
-@app.on_event("startup")
-async def initialize_vector_store():
-    """Initialize vector store tables and indexes on application startup."""
-    # First initialize the primary vector store (PGVectorStore if using pgvector)
-    logger.info("Initializing primary vector store...")
-    if hasattr(vector_store, "initialize"):
-        success = await vector_store.initialize()
-        if success:
-            logger.info("Primary vector store initialization successful")
-        else:
-            logger.error("Primary vector store initialization failed")
-    else:
-        logger.warning("Primary vector store does not have an initialize method")
-
-    # Then initialize the multivector store if enabled
-    if settings.ENABLE_COLPALI and colpali_vector_store:
-        logger.info("Initializing multivector store...")
-        # Handle both synchronous and asynchronous initialize methods
-        if hasattr(colpali_vector_store.initialize, "__awaitable__"):
-            success = await colpali_vector_store.initialize()
-        else:
-            success = colpali_vector_store.initialize()
-
-        if success:
-            logger.info("Multivector store initialization successful")
-        else:
-            logger.error("Multivector store initialization failed")
-
-
-@app.on_event("startup")
-async def initialize_user_limits_database():
-    """Initialize user service on application startup."""
-    logger.info("Initializing user service...")
-    if settings.MODE == "cloud":
-        from core.database.user_limits_db import UserLimitsDatabase
-
-        user_limits_db = UserLimitsDatabase(uri=settings.POSTGRES_URI)
-        await user_limits_db.initialize()
-
-
-@app.on_event("startup")
-async def initialize_redis_pool():
-    """Initialize the Redis connection pool for background tasks."""
-    global redis_pool
-    logger.info("Initializing Redis connection pool...")
-
-    # Get Redis settings from configuration
-    redis_host = settings.REDIS_HOST
-    redis_port = settings.REDIS_PORT
-
-    # Log the Redis connection details
-    logger.info(f"Connecting to Redis at {redis_host}:{redis_port}")
-
-    redis_settings = arq.connections.RedisSettings(
-        host=redis_host,
-        port=redis_port,
-    )
-
-    redis_pool = await arq.create_pool(redis_settings)
-    logger.info("Redis connection pool initialized successfully")
-
-
-@app.on_event("shutdown")
-async def close_redis_pool():
-    """Close the Redis connection pool on application shutdown."""
-    global redis_pool
-    if redis_pool:
-        logger.info("Closing Redis connection pool...")
-        redis_pool.close()
-        await redis_pool.wait_closed()
-        logger.info("Redis connection pool closed")
-
-
 # Initialize vector store
 if not settings.POSTGRES_URI:
     raise ValueError("PostgreSQL URI is required for pgvector store")
@@ -279,18 +238,22 @@ async def close_redis_pool():

 # Initialize document service with configured components
 document_service = DocumentService(
-    storage=storage,
     database=database,
     vector_store=vector_store,
+    storage=storage,
+    parser=parser,
     embedding_model=embedding_model,
     completion_model=completion_model,
-    parser=parser,
-    reranker=reranker,
     cache_factory=cache_factory,
-    enable_colpali=(settings.COLPALI_MODE != "off"),
+    reranker=reranker,
+    enable_colpali=settings.ENABLE_COLPALI,
     colpali_embedding_model=colpali_embedding_model,
     colpali_vector_store=colpali_vector_store,
 )
+# Store document_service on app.state immediately after it's created.
+# This must happen before _init_ee_app(app) is called.
+app.state.document_service = document_service
+logger.info("Document service initialized and stored on app.state")

 # Initialize the MorphikAgent once to load tool definitions and avoid repeated I/O
 morphik_agent = MorphikAgent(document_service=document_service)
@@ -304,10 +267,13 @@ async def close_redis_pool():
     from ee.routers import init_app as _init_ee_app  # type: ignore

     _init_ee_app(app)  # noqa: SLF001 – runtime extension
-    logger.info("Enterprise routes mounted (ee package detected).")
 except ModuleNotFoundError:
     # Expected in OSS builds – silently ignore.
     logger.debug("Enterprise package not found – running in community mode.")
+except ImportError as e:
+    logger.error(f"Failed to import init_app from ee.routers: {e}", exc_info=True)
+except Exception as e:
+    logger.error(f"An unexpected error occurred during EE app initialization: {e}", exc_info=True)


 @app.post("/ingest/text", response_model=Document)
@@ -349,15 +315,6 @@ async def ingest_text(
        raise HTTPException(status_code=403, detail=str(e))


-# Redis pool for background tasks
-redis_pool = None
-
-
-def get_redis_pool():
-    """Get the global Redis connection pool for background tasks."""
-    return redis_pool
-
-
 @app.post("/ingest/file", response_model=Document)
 @telemetry.track(operation_type="queue_ingest_file", metadata_resolver=telemetry.ingest_file_metadata)
 async def ingest_file(
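With the module-level redis_pool global and its get_redis_pool() accessor removed in the last hunk, routes are expected to receive the pool through the request-scoped dependency imported from core.dependencies at the top of the file. A sketch of that wiring, under the assumption of a route like the one below; the route path and the process_example_job task name are made up for illustration:

import arq
from fastapi import APIRouter, Depends

from core.dependencies import get_redis_pool

router = APIRouter()


@router.post("/example/enqueue")
async def enqueue_example(redis: arq.ArqRedis = Depends(get_redis_pool)):
    # The pool is resolved from app.state.redis_pool, created in the lifespan above.
    job = await redis.enqueue_job("process_example_job", payload={"doc": "example"})
    return {"job_id": job.job_id if job else None}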

core/config.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,7 @@ class Settings(BaseSettings):

     # Environment variables
     JWT_SECRET_KEY: str
+    SESSION_SECRET_KEY: str
     POSTGRES_URI: Optional[str] = None
     UNSTRUCTURED_API_KEY: Optional[str] = None
     AWS_ACCESS_KEY: Optional[str] = None
@@ -155,6 +156,7 @@ def get_settings() -> Settings:
     auth_config = {
         "JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
         "JWT_SECRET_KEY": os.environ.get("JWT_SECRET_KEY", "dev-secret-key"),  # Default for dev mode
+        "SESSION_SECRET_KEY": os.environ.get("SESSION_SECRET_KEY", "super-secret-dev-session-key"),
         "dev_mode": config["auth"].get("dev_mode", False),
         "dev_entity_type": config["auth"].get("dev_entity_type", "developer"),
         "dev_entity_id": config["auth"].get("dev_entity_id", "dev_user"),
@@ -164,6 +166,13 @@ def get_settings() -> Settings:
     # Only require JWT_SECRET_KEY in non-dev mode
     if not auth_config["dev_mode"] and "JWT_SECRET_KEY" not in os.environ:
         raise ValueError("JWT_SECRET_KEY is required when dev_mode is disabled")
+    # Also require SESSION_SECRET_KEY in non-dev mode
+    if not auth_config["dev_mode"] and "SESSION_SECRET_KEY" not in os.environ:
+        # Or, if we want to be more strict and always require it via ENV:
+        # if "SESSION_SECRET_KEY" not in os.environ:
+        #     raise ValueError("SESSION_SECRET_KEY environment variable is required.")
+        # For now, align with JWT_SECRET_KEY's dev mode leniency.
+        pass  # Dev mode has a default, production should use ENV.

     # Load registered models if available
     registered_models = {}
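In dev mode both secrets fall back to the defaults shown above; outside dev mode JWT_SECRET_KEY must come from the environment, and SESSION_SECRET_KEY should as well. A small sketch (not part of the commit) for generating strong values to export before starting the API in production:

import secrets

# Print values suitable for the JWT_SECRET_KEY and SESSION_SECRET_KEY
# environment variables; store them in a secret manager, not in git.
print("JWT_SECRET_KEY=" + secrets.token_urlsafe(48))
print("SESSION_SECRET_KEY=" + secrets.token_urlsafe(48))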

core/dependencies.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+from typing import TYPE_CHECKING
+
+import arq
+from fastapi import Request
+
+if TYPE_CHECKING:
+    from core.services.document_service import DocumentService
+
+
+async def get_redis_pool(request: Request) -> arq.ArqRedis:
+    if not hasattr(request.app.state, "redis_pool") or request.app.state.redis_pool is None:
+        raise RuntimeError("Redis pool not initialized or not available on app.state")
+    return request.app.state.redis_pool
+
+
+async def get_document_service(request: Request) -> "DocumentService":
+    if not hasattr(request.app.state, "document_service") or request.app.state.document_service is None:
+        raise RuntimeError("Document service not initialized or not available on app.state")
+    return request.app.state.document_service
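Because both helpers resolve their objects from request.app.state, they can be injected anywhere with Depends and swapped out in tests via dependency_overrides, without touching module-level globals. A minimal sketch under those assumptions; the /documents/count route and the count_documents method are hypothetical, used only for illustration:

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

from core.dependencies import get_document_service

app = FastAPI()


@app.get("/documents/count")
async def count_documents(document_service=Depends(get_document_service)):
    # In the real app this resolves app.state.document_service, set at startup.
    return {"count": await document_service.count_documents()}


class FakeDocumentService:
    async def count_documents(self) -> int:
        return 0


# In tests, bypass app.state entirely by overriding the dependency.
app.dependency_overrides[get_document_service] = lambda: FakeDocumentService()
client = TestClient(app)
assert client.get("/documents/count").json() == {"count": 0}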

core/models/documents.py

Lines changed: 8 additions & 1 deletion
@@ -37,7 +37,7 @@ class Document(BaseModel):
     filename: Optional[str] = None
     metadata: Dict[str, Any] = Field(default_factory=dict)
     """user-defined metadata"""
-    storage_info: Dict[str, str] = Field(default_factory=dict)
+    storage_info: Dict[str, Any] = Field(default_factory=dict)
     """Legacy field for backwards compatibility - for single file storage"""
     storage_files: List[StorageFileInfo] = Field(default_factory=list)
     """List of files associated with this document"""
@@ -57,6 +57,13 @@ class Document(BaseModel):
     access_control: Dict[str, List[str]] = Field(default_factory=lambda: {"readers": [], "writers": [], "admins": []})
     chunk_ids: List[str] = Field(default_factory=list)

+    # Ensure storage_info values are strings to maintain backward compatibility
+    @field_validator("storage_info", mode="before")
+    def _coerce_storage_info_values(cls, v):
+        if isinstance(v, dict):
+            return {k: str(val) if val is not None else "" for k, val in v.items()}
+        return v
+
     def __hash__(self):
         return hash(self.external_id)
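The new "before"-mode validator loosens the declared type to Dict[str, Any] at the boundary while still storing plain strings, so a caller that reports, say, an integer file size does not break older consumers of storage_info. A standalone sketch of the same coercion on a stripped-down model (the keys and values below are illustrative):

from typing import Any, Dict

from pydantic import BaseModel, Field, field_validator


class Doc(BaseModel):
    storage_info: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("storage_info", mode="before")
    def _coerce(cls, v):
        # Runs before type validation; stringify every value, mapping None to "".
        if isinstance(v, dict):
            return {k: str(val) if val is not None else "" for k, val in v.items()}
        return v


doc = Doc(storage_info={"bucket": "example-bucket", "size": 1024, "etag": None})
print(doc.storage_info)  # {'bucket': 'example-bucket', 'size': '1024', 'etag': ''}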
