-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
225 lines (190 loc) · 7.66 KB
/
app.py
File metadata and controls
225 lines (190 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import json
import os
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from traceback import print_exc
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from llm_chain import ProjectChatChain
# Load environment variables from a local .env file (if present) before any
# configuration below is read.
load_dotenv()

# Database setup. Both settings may be overridden from the environment / .env;
# the defaults reproduce the previous hard-coded values. App-specific variable
# names are used so a platform-provided DATABASE_URL (often a sync driver URL)
# is not picked up by accident.
DATABASE_URL = os.getenv("CHAT_DATABASE_URL", "sqlite+aiosqlite:///./chat.db")
# echo=True makes SQLAlchemy log every statement together with its bound
# parameters (here: users' chat content). Set CHAT_SQL_ECHO=false to silence it.
SQL_ECHO = os.getenv("CHAT_SQL_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}
engine = create_async_engine(DATABASE_URL, echo=SQL_ECHO)
# expire_on_commit=False keeps ORM attributes readable after commit without an
# implicit refresh, which would need I/O that an AsyncSession cannot do lazily.
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
Base = declarative_base()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up shared resources at startup and release them at shutdown.

    Startup builds the one ``ProjectChatChain`` that every request uses
    (stored on ``app.state.chat_chain``) and creates any missing tables.
    Shutdown disposes the engine's connection pool.
    """
    app.state.chat_chain = ProjectChatChain()
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        # Runs even when table creation fails or the app exits through an
        # exception, so pooled connections are always released.
        await engine.dispose()
# Create the FastAPI application; startup/shutdown work lives in ``lifespan``.
app = FastAPI(title="IdeaGO Chat API", lifespan=lifespan)

# Configure CORS. Allowed origins come from CORS_ALLOWED_ORIGINS as a
# comma-separated list (e.g. "https://app.example.com,http://localhost:3000").
# An unset or empty variable falls back to "*", the previous behaviour.
# NOTE(review): Starlette documents that wildcard origins should not be combined
# with allow_credentials=True. In that configuration its middleware echoes the
# caller's Origin back, so any site can make credentialed requests. Configure an
# explicit origin list in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Database models
class ChatSession(Base):
    """A conversation, tagged with the ``user_id`` that started it.

    ``chat_messages`` and ``project_data`` rows reference it via ``session_id``.
    """

    __tablename__ = "chat_sessions"

    # Random UUID4 string assigned on insert when no id is given.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String, nullable=False)
    # Naive UTC timestamps. datetime.utcnow() is deprecated since Python 3.12;
    # now(timezone.utc) with tzinfo stripped yields the same naive-UTC value.
    created_at = Column(
        DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
class ChatMessage(Base):
    """One stored chat turn belonging to a ``ChatSession``."""

    __tablename__ = "chat_messages"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Indexed because the /chat route looks up a session's history by this
    # column. create_all only builds the index for newly created tables;
    # existing databases need a manual CREATE INDEX.
    session_id = Column(String, ForeignKey("chat_sessions.id"), index=True)
    role = Column(String, nullable=False)  # 'user' or 'assistant'
    content = Column(String, nullable=False)
    # Naive UTC (same value datetime.utcnow() gave, without its 3.12
    # deprecation warning); also the sort key when replaying history.
    created_at = Column(
        DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
class ProjectData(Base):
    """A JSON snapshot of parsed project data produced during a session."""

    __tablename__ = "project_data"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Indexed because the /chat route fetches a session's latest snapshot by
    # this column. create_all only builds the index for newly created tables;
    # existing databases need a manual CREATE INDEX.
    session_id = Column(String, ForeignKey("chat_sessions.id"), index=True)
    project_data = Column(JSON)
    # Naive UTC timestamps (same values datetime.utcnow() gave, without its
    # 3.12 deprecation warning). created_at picks the latest snapshot.
    created_at = Column(
        DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
# Pydantic models
class MessageCreate(BaseModel):
    """Request body for ``POST /chat``.

    Omit ``session_id`` (or send null / an empty string) to start a new
    conversation. Pass the ``session_id`` returned by a previous response to
    continue that conversation.
    """

    user_id: str
    session_id: Optional[str] = None
    content: str  # the user's message text, forwarded to the LLM chain
class Message(BaseModel):
    """A single chat turn as returned to the client."""

    role: str  # 'user' or 'assistant'; the /chat route only returns 'assistant'
    content: str
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    session_id: str  # send this back on the next request to continue the conversation
    messages: Message  # NOTE: despite the plural name, this is the single assistant reply
    project_data: Optional[dict] = None  # latest stored project data for the session, if any
# Database dependency
async def get_db():
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    The ``async with`` block closes the session when the request finishes,
    including when the endpoint raises. That releases its connection and
    discards any uncommitted work, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session
# Routes
@app.get("/")
async def root():
    """Liveness probe: confirm the API is up and report its version."""
    payload = {"status": "IdeaGO Chat API is running"}
    payload["version"] = "1.0.0"
    return payload
# Parameterized queries used by the /chat route.
_HISTORY_QUERY = text(
    "SELECT role, content FROM chat_messages "
    "WHERE session_id = :session_id ORDER BY created_at ASC"
)
_LATEST_PROJECT_DATA_QUERY = text(
    "SELECT project_data FROM project_data "
    "WHERE session_id = :session_id ORDER BY created_at DESC LIMIT 1"
)


def _normalize_talents(data):
    """Normalize a project-data dict's talent field in place and return it.

    A singular ``"talent"`` entry becomes a one-element ``"talents"`` list. A
    non-list ``"talents"`` value is wrapped in a list. Values that are not dicts
    are returned unchanged.
    """
    if not isinstance(data, dict):
        return data
    if "talent" in data and "talents" not in data:
        data["talents"] = [data.pop("talent")]
    elif not isinstance(data.get("talents", []), list):
        data["talents"] = [data["talents"]]
    return data


async def _get_or_create_session_id(db: AsyncSession, message: MessageCreate) -> str:
    """Return the session id for *message*, creating a new session if none was given.

    Raises:
        HTTPException(404): ``message.session_id`` does not exist, or it belongs
            to a different ``user_id``.
    """
    if not message.session_id:
        session = ChatSession(user_id=message.user_id)
        db.add(session)
        await db.commit()
        await db.refresh(session)
        return session.id

    session = await db.get(ChatSession, message.session_id)
    # Same answer for "missing" and "someone else's", so ids cannot be probed.
    # This also stops messages being stored against a non-existent session.
    if session is None or session.user_id != message.user_id:
        raise HTTPException(status_code=404, detail="Session not found")
    return session.id


async def _load_history_into_memory(db: AsyncSession, chat_chain, session_id: str) -> None:
    """Replace the chain's memory with the stored history of *session_id*.

    Memory is always cleared first. Previously, history was only loaded into
    empty memory, so a session could continue with another session's
    conversation still in the shared chain.
    """
    rows = await db.execute(_HISTORY_QUERY, {"session_id": session_id})
    history = rows.fetchall()
    chat_chain.clear_memory()
    if not history:
        return

    from langchain.schema import AIMessage, HumanMessage

    print(f"Loading {len(history)} previous messages into memory")
    for role, content in history:
        if role == "user":
            chat_chain.memory.chat_memory.add_message(HumanMessage(content=content))
        elif role == "assistant":
            chat_chain.memory.chat_memory.add_message(AIMessage(content=content))


async def _latest_project_data(db: AsyncSession, session_id: str) -> Optional[dict]:
    """Return the most recently stored project-data dict for *session_id*, or None.

    Undecodable or non-object payloads are logged and reported as None instead
    of failing the whole request.
    """
    row = (await db.execute(_LATEST_PROJECT_DATA_QUERY, {"session_id": session_id})).first()
    if not row or not row[0]:
        return None

    value = row[0]
    # A raw text() query bypasses the JSON column type, so the payload may come
    # back as its serialized string (e.g. on SQLite).
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            print_exc()
            return None
    if not isinstance(value, dict):
        print(f"Ignoring non-object project data for session {session_id}")
        return None
    return _normalize_talents(value)


@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: Request,
    message: MessageCreate,
    db: AsyncSession = Depends(get_db),
):
    """Handle one user turn: persist it, run the LLM chain, persist the reply.

    Starts a new session when ``message.session_id`` is empty. Otherwise it
    replays the session's stored history into the chain's memory. When the chain
    reports a final result with parsed data, that data is stored as a new
    ``ProjectData`` snapshot. The response carries the reply and the session's
    latest project data.

    Raises:
        HTTPException(404): unknown session, or a session owned by another user.
        HTTPException(500): any other failure. Details are logged, not returned.
    """
    try:
        # NOTE(review): one ProjectChatChain (and its memory) is shared by every
        # request. Memory is rebuilt from the DB on each turn, but requests whose
        # LLM calls overlap can still interleave history. A per-session chain or
        # a lock around the chain would be needed to rule that out.
        chat_chain = request.app.state.chat_chain

        session_id = await _get_or_create_session_id(db, message)
        await _load_history_into_memory(db, chat_chain, session_id)

        # Added after the history query, so the chain receives the current
        # message only via process_message(), not twice.
        db.add(ChatMessage(session_id=session_id, role="user", content=message.content))

        llm_result = await chat_chain.process_message(message.content, session_id)
        reply = llm_result["response"]
        db.add(ChatMessage(session_id=session_id, role="assistant", content=reply))

        if llm_result["is_final"] and llm_result["parsed_data"]:
            db.add(
                ProjectData(
                    session_id=session_id,
                    project_data=_normalize_talents(llm_result["parsed_data"]),
                )
            )

        await db.commit()

        project_data = await _latest_project_data(db, session_id)
        return ChatResponse(
            session_id=session_id,
            messages=Message(role="assistant", content=reply),
            project_data=project_data,
        )
    except HTTPException:
        # Deliberate client errors (e.g. 404) must not be turned into 500s.
        raise
    except Exception as exc:
        print_exc()
        # Do not echo str(exc) to clients; it can expose SQL, file paths or
        # upstream LLM/API details. The traceback above keeps the details.
        raise HTTPException(status_code=500, detail="Internal server error") from exc
if __name__ == "__main__":
    import uvicorn

    # Development server: restart on any *.py change, ignoring bytecode caches
    # and dot-files.
    dev_server_options = dict(
        host="0.0.0.0",
        port=8000,
        reload=True,
        reload_includes=["*.py"],
        reload_excludes=["__pycache__/*", ".*", "*.pyc"],
    )
    uvicorn.run("app:app", **dev_server_options)