Skip to content

Commit 3ac64a4

Browse files
committed
feat(collab): basic bidirectional comms between users
1 parent 2b3c42f commit 3ac64a4

File tree

13 files changed

+300
-0
lines changed

13 files changed

+300
-0
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
1+
.venv/
2+
3+
.pytest_cache/
14
**/__pycache__/
5+
6+
**/node_modules/
7+
.env
8+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## To run locally:
2+
`uvicorn main:app --reload`
3+
## Planning
4+
5+
## APIs to implement:
6+
7+
From Matching (REST)
8+
[ ] /sessions
9+
-
10+
11+
From UI (Websocket)
12+
[ ] /
13+
! make session_id non-determined from user ids (for security purpose; hacker cannot hijack session solely from user data)
File renamed without changes.
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Dict
2+
from fastapi import WebSocket, WebSocketDisconnect
3+
import uuid
4+
import asyncio
5+
from app.schemas.messages import ConnectMessage, DisconnectMessage, DisplayMessage
6+
7+
class ConnectionManager:
8+
def __init__(self):
9+
self.active_connections: Dict[str, Dict[str, asyncio.Queue]] = {}
10+
self.condition = asyncio.Condition()
11+
12+
def init_session(self, user_ids: list[str]) -> str:
13+
if len(user_ids) != 2:
14+
raise ValueError("user_ids must contain exactly two user IDs.")
15+
session_id = self.generate_uuid(user_ids[0], user_ids[1])
16+
17+
self.active_connections[session_id] = {}
18+
for user_id in user_ids:
19+
self.active_connections[session_id][user_id] = None
20+
return session_id
21+
22+
def _get_collaborator_q(self, session_id, user_id):
23+
session_data = self.active_connections[session_id]
24+
id = next((uid for uid in session_data if uid != user_id), None)
25+
if id is None:
26+
return None
27+
return session_data[id]
28+
29+
async def on_connect(self, session_id: str, user_id: str, out_q: asyncio.Queue):
30+
if session_id not in self.active_connections:
31+
raise ValueError("Session not initialized")
32+
33+
if user_id not in self.active_connections[session_id]:
34+
raise ValueError("user_id not available in this session")
35+
36+
self.active_connections[session_id][user_id] = out_q
37+
38+
collaborator_q = self._get_collaborator_q(session_id, user_id)
39+
40+
if collaborator_q is None:
41+
return None
42+
43+
await collaborator_q.put(ConnectMessage())
44+
45+
async def on_disconnect(self, session_id: str, user_id: str):
46+
if session_id not in self.active_connections:
47+
raise ValueError("Session not initialized")
48+
49+
if user_id not in self.active_connections[session_id]:
50+
raise ValueError("user_id not available in this session")
51+
52+
self.active_connections[session_id][user_id] = None
53+
54+
collaborator_q = self._get_collaborator_q(session_id, user_id)
55+
56+
if collaborator_q is None:
57+
return None
58+
59+
await collaborator_q.put(DisconnectMessage())
60+
61+
async def on_message(self, session_id: str, user_id: str, msg: str):
62+
collaborator_q = self._get_collaborator_q(session_id, user_id)
63+
64+
if collaborator_q is None:
65+
return None
66+
67+
await collaborator_q.put(DisplayMessage(msg=msg))
68+
69+
def generate_uuid(self, user1: str, user2: str) -> str:
70+
namespace = uuid.NAMESPACE_DNS # or uuid.NAMESPACE_URL, etc.
71+
seed = ''.join(sorted([user1, user2]))
72+
return str(uuid.uuid5(namespace, seed))
73+
74+
# async def connect(self, session_id: str, user_id: str, ws: WebSocket):
75+
# if session_id not in self.active_connections:
76+
# raise ValueError("Session not initialized.")
77+
78+
# if user_id not in self.active_connections[session_id]:
79+
# raise ValueError("user_id not recognized in this session.")
80+
81+
# await ws.accept()
82+
# self.active_connections[session_id][user_id] = ws
83+
84+
# # Notify any waiting coroutines that a user has connected
85+
# async with self.condition:
86+
# self.condition.notify_all()
87+
88+
# return None
89+
90+
# async def await_collaborator_ws(self, session_id: str, user_id: str) -> WebSocket | None:
91+
# if session_id not in self.active_connections:
92+
# raise ValueError("Session not initialized.")
93+
94+
# session_data = self.active_connections[session_id]
95+
# collaborator_id = next((uid for uid in session_data if uid != user_id), None)
96+
97+
# user_ws = session_data[user_id]
98+
# collaborator_ws = session_data[collaborator_id]
99+
100+
# try:
101+
# if collaborator_ws is None:
102+
# async with self.condition:
103+
# await user_ws.send_text(f"Waiting for collaborator {collaborator_id} to connect...")
104+
# await self.condition.wait_for(lambda: session_data[collaborator_id] is not None)
105+
# await user_ws.send_text(f"Collaborator {collaborator_id} connected!")
106+
# collaborator_ws = session_data[collaborator_id]
107+
# except Exception as e:
108+
# print(f"Error while waiting for collaborator: {e}")
109+
# return None
110+
111+
# return session_data[collaborator_id]
112+
113+
# async def disconnect(self, session_id: str, user_id: str):
114+
# if session_id not in self.active_connections:
115+
# return
116+
117+
# session_data = self.active_connections[session_id]
118+
119+
# ws = session_data[user_id]
120+
# session_data[user_id] = None
121+
122+
# # Notify any waiting coroutines that a user has disconnected
123+
# async with self.condition:
124+
# self.condition.notify_all()
125+
126+
# if ws:
127+
# try:
128+
# await ws.close()
129+
# except Exception:
130+
# pass
131+
132+
# def get_active_connections(self, session_id: str):
133+
# return self.active_connections.get(session_id, {})
134+
135+
136+
connection_manager = ConnectionManager()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import Enum
2+
3+
class UserState(Enum):
4+
AWAIT_CONNECT = "awaiting_connection"
5+
AWAIT_POLLING = "awaiting_polling"
6+
7+
8+
9+

services/collaboration-service/app/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from fastapi import FastAPI
2+
from app.ws import sessions as ws_sessions
3+
from app.rest import sessions
24

35
app = FastAPI()
6+
app.include_router(sessions.router, prefix="/sessions")
7+
app.include_router(ws_sessions.router, prefix="/ws/sessions")
48

59
## API for testing connection
610
@app.get("/ping")

services/collaboration-service/app/rest/__init__.py

Whitespace-only changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
from fastapi import APIRouter
3+
from app.schemas.collaboration import CreateSessionRequest, CreateSessionResponse
4+
from app.core.connection_manager import connection_manager
5+
6+
router = APIRouter()
7+
8+
@router.post("/", response_model=CreateSessionResponse)
9+
def create_session(request: CreateSessionRequest):
10+
# print(f"Creating session for users: {request.user_ids}")
11+
user_ids = request.user_ids
12+
session_id = connection_manager.init_session(user_ids)
13+
return CreateSessionResponse(session_id=session_id)
14+
15+
16+

services/collaboration-service/app/schemas/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
from pydantic import BaseModel, conlist
3+
from typing import List
4+
from datetime import datetime
5+
6+
## TODO: Finalize CreateSessionRequest and CreateSessionResponse
7+
class CreateSessionRequest(BaseModel):
8+
user_ids: List[str] = conlist(str, min_length=2, max_length=2)
9+
# created_at: datetime = datetime.now()
10+
## question: Question
11+
12+
class CreateSessionResponse(BaseModel):
13+
session_id: str
14+
# created_at: datetime

0 commit comments

Comments
 (0)