Skip to content

Commit bd220ef

Browse files
feat: simplify backend states (#594)
Update app.py
1 parent 4c96b9a commit bd220ef

File tree

2 files changed

+114
-130
lines changed

2 files changed

+114
-130
lines changed

src/backend/app/routes/ws/reverse.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from app.deps import RedisDep, S3Dep
1313
from app.schemas.reverse import RoomFileEntry, RoomFileEvent
1414
from app.settings import settings
15+
from app.states.app import UploadProgress
1516
from app.states.room import RoomState
1617

1718
logger = logging.getLogger(__name__)
@@ -29,46 +30,66 @@ async def _stream_file_to_ws(
2930
send_lock: asyncio.Lock,
3031
) -> None:
3132
"""Fetch one file from RustFS and relay every chunk to *ws*."""
33+
upload_key = f"ws:{uuid.uuid4()}"
34+
await UploadProgress.start(
35+
upload_key=upload_key, filename=entry.filename, track_space=False
36+
)
37+
38+
success = False
3239
try:
33-
s3_response = await s3_client.get_object(
34-
Bucket=settings.RUSTFS_BUCKET_NAME,
35-
Key=entry.key,
36-
)
37-
except ClientError:
40+
try:
41+
s3_response = await s3_client.get_object(
42+
Bucket=settings.RUSTFS_BUCKET_NAME,
43+
Key=entry.key,
44+
)
45+
except ClientError:
46+
async with send_lock:
47+
await ws.send_text(
48+
json.dumps(
49+
{
50+
"type": "file_error",
51+
"key": entry.key,
52+
"detail": "Not found in storage",
53+
}
54+
)
55+
)
56+
return
57+
3858
async with send_lock:
3959
await ws.send_text(
4060
json.dumps(
4161
{
42-
"type": "file_error",
62+
"type": "file_start",
4363
"key": entry.key,
44-
"detail": "Not found in storage",
64+
"filename": entry.filename,
65+
"size": entry.size,
4566
}
4667
)
4768
)
48-
return
4969

50-
async with send_lock:
51-
await ws.send_text(
52-
json.dumps(
53-
{
54-
"type": "file_start",
55-
"key": entry.key,
56-
"filename": entry.filename,
57-
"size": entry.size,
58-
}
59-
)
60-
)
70+
body = s3_response["Body"]
71+
uploaded_bytes = 0
72+
try:
73+
async for chunk in body.iter_chunks(S3_CHUNK_SIZE):
74+
async with send_lock:
75+
await ws.send_bytes(chunk)
76+
uploaded_bytes += len(chunk)
77+
await UploadProgress.update(
78+
upload_key=upload_key, uploaded_bytes=uploaded_bytes
79+
)
80+
finally:
81+
body.close()
6182

62-
body = s3_response["Body"]
63-
try:
64-
async for chunk in body.iter_chunks(S3_CHUNK_SIZE):
65-
async with send_lock:
66-
await ws.send_bytes(chunk)
67-
finally:
68-
body.close()
83+
async with send_lock:
84+
await ws.send_text(json.dumps({"type": "file_end", "key": entry.key}))
6985

70-
async with send_lock:
71-
await ws.send_text(json.dumps({"type": "file_end", "key": entry.key}))
86+
await UploadProgress.finish(upload_key=upload_key, final_size=entry.size)
87+
success = True
88+
except (asyncio.CancelledError, Exception):
89+
raise
90+
finally:
91+
if not success:
92+
await UploadProgress.cancel(upload_key=upload_key)
7293

7394

7495
@router.websocket("/ws/reverse/rooms/{room_id}")

src/backend/app/states/app.py

Lines changed: 65 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -6,79 +6,36 @@
66
from redis.asyncio import Redis
77

88
from app.settings import settings
9-
109
from ._global import GlobalState
1110

1211

13-
class UploadProgress(GlobalState, BaseModel):
12+
class UploadProgress(BaseModel):
1413
upload_key: str
1514
filename: str
1615
uploaded_bytes: int = 0
1716
done: bool = False
1817
last_updated: datetime | None = None
18+
track_space: bool = True
1919

2020
@classmethod
21-
async def start(cls, upload_key: str, filename: str) -> "AppState":
22-
"""Register a new in-flight upload and set last_updated."""
23-
current = await AppState.get()
24-
now = datetime.now(timezone.utc)
25-
current.active_uploads.append(
26-
cls(upload_key=upload_key, filename=filename, last_updated=now)
21+
async def start(cls, upload_key: str, filename: str, track_space: bool = True):
22+
return await AppState._update_active(
23+
upload_key, filename=filename, track_space=track_space
2724
)
28-
await AppState.set(current)
29-
return current
3025

3126
@classmethod
32-
async def update(cls, upload_key: str, uploaded_bytes: int) -> "AppState":
33-
"""Update the byte count and touch last_updated for an in-flight upload."""
34-
current = await AppState.get()
35-
now = datetime.now(timezone.utc)
36-
for upload in current.active_uploads:
37-
if upload.upload_key == upload_key:
38-
# Add the delta to total_space_used so partial uploads are accounted
39-
delta = uploaded_bytes - (upload.uploaded_bytes or 0)
40-
if delta > 0:
41-
current.total_space_used += delta
42-
upload.uploaded_bytes = uploaded_bytes
43-
upload.last_updated = now
44-
break
45-
await AppState.set(current)
46-
return current
27+
async def update(cls, upload_key: str, uploaded_bytes: int):
28+
return await AppState._update_active(upload_key, uploaded_bytes=uploaded_bytes)
4729

4830
@classmethod
49-
async def finish(cls, upload_key: str, final_size: int) -> "AppState":
50-
"""Mark an upload as done, touch last_updated, and add its size to total_space_used."""
51-
current = await AppState.get()
52-
now = datetime.now(timezone.utc)
53-
for upload in current.active_uploads:
54-
if upload.upload_key == upload_key:
55-
# If we tracked partial bytes earlier, only add the remaining delta
56-
delta = final_size - (upload.uploaded_bytes or 0)
57-
if delta > 0:
58-
current.total_space_used += delta
59-
upload.done = True
60-
upload.uploaded_bytes = final_size
61-
upload.last_updated = now
62-
break
63-
await AppState.set(current)
64-
return current
31+
async def finish(cls, upload_key: str, final_size: int):
32+
return await AppState._update_active(
33+
upload_key, uploaded_bytes=final_size, done=True
34+
)
6535

6636
@classmethod
67-
async def cancel(cls, upload_key: str) -> "AppState":
68-
"""Remove a failed/disconnected upload from the active list without adjusting totals."""
69-
current = await AppState.get()
70-
# Subtract any tracked uploaded bytes from totals when cancelling
71-
remaining = []
72-
freed = 0
73-
for u in current.active_uploads:
74-
if u.upload_key == upload_key:
75-
freed += u.uploaded_bytes or 0
76-
else:
77-
remaining.append(u)
78-
current.active_uploads = remaining
79-
current.total_space_used = max(0, current.total_space_used - freed)
80-
await AppState.set(current)
81-
return current
37+
async def cancel(cls, upload_key: str):
38+
return await AppState._update_active(upload_key, remove=True)
8239

8340

8441
class AppState(GlobalState, BaseModel):
@@ -87,37 +44,12 @@ class AppState(GlobalState, BaseModel):
8744
active_uploads: list[UploadProgress] = []
8845

8946
@classmethod
90-
async def ensure_created(cls) -> None:
91-
"""Ensure a state document exists in RedisJSON, syncing `total_available_space` from Config."""
92-
from sqlmodel import select
93-
94-
from app.db import AsyncSessionLocal
95-
from app.models.config import Config
96-
97-
async with AsyncSessionLocal() as session:
98-
result = await session.exec(select(Config))
99-
config = result.first()
100-
total_available_space = config.total_storage_limit if config else None
101-
102-
existing = await cls._json_get(settings.STATE_REDIS_KEY)
103-
if existing is None:
104-
state = cls(total_available_space=total_available_space)
105-
await cls._json_set(settings.STATE_REDIS_KEY, state.model_dump(mode="json"))
106-
else:
107-
existing["total_available_space"] = total_available_space
108-
await cls._json_set(settings.STATE_REDIS_KEY, existing)
109-
110-
@classmethod
111-
async def get(cls, redis_client: Redis | None = None) -> AppState:
112-
"""Return the full current state from RedisJSON."""
47+
async def get(cls, redis_client: Redis | None = None) -> "AppState":
11348
data = await cls._json_get(settings.STATE_REDIS_KEY, redis_client=redis_client)
114-
if data is None:
115-
return cls()
116-
return cls.model_validate(data)
49+
return cls.model_validate(data) if data else cls()
11750

11851
@classmethod
119-
async def set(cls, state: AppState, redis_client: Redis | None = None) -> None:
120-
"""Overwrite the entire state and notify all app instances."""
52+
async def set(cls, state: "AppState", redis_client: Redis | None = None) -> None:
12153
await cls._json_set(
12254
settings.STATE_REDIS_KEY,
12355
state.model_dump(mode="json"),
@@ -127,21 +59,47 @@ async def set(cls, state: AppState, redis_client: Redis | None = None) -> None:
12759
settings.STATE_CHANNEL, state.model_dump_json(), redis_client=redis_client
12860
)
12961

62+
@classmethod
63+
async def _update_active(cls, key: str, remove=False, **kwargs) -> "AppState":
64+
"""Unified internal method for all upload state changes (start, update, finish, cancel)."""
65+
s = await cls.get()
66+
now = datetime.now(timezone.utc)
67+
upload = next((u for u in s.active_uploads if u.upload_key == key), None)
68+
69+
# Update Space Tracking (if track_space is True)
70+
if upload and upload.track_space:
71+
if remove:
72+
s.total_space_used = max(0, s.total_space_used - upload.uploaded_bytes)
73+
elif "uploaded_bytes" in kwargs:
74+
s.total_space_used += kwargs["uploaded_bytes"] - upload.uploaded_bytes
75+
76+
# Update Active Uploads List
77+
if remove or (upload and kwargs.get("done") and not upload.track_space):
78+
s.active_uploads = [u for u in s.active_uploads if u.upload_key != key]
79+
elif not upload:
80+
s.active_uploads.append(
81+
UploadProgress(upload_key=key, last_updated=now, **kwargs)
82+
)
83+
else:
84+
for k, v in kwargs.items():
85+
setattr(upload, k, v)
86+
upload.last_updated = now
87+
88+
await cls.set(s)
89+
return s
90+
13091
@classmethod
13192
async def evict_files(
13293
cls, file_keys: list[str], freed_bytes: int, redis_client: Redis | None = None
13394
) -> None:
134-
135-
async def _evict(client: Redis) -> None:
136-
current = await cls.get(redis_client=client)
137-
key_set = set(file_keys)
138-
current.active_uploads = [
139-
u for u in current.active_uploads if u.upload_key not in key_set
140-
]
141-
current.total_space_used = max(0, current.total_space_used - freed_bytes)
142-
await cls.set(current, redis_client=client)
143-
144-
if redis_client is not None:
95+
async def _evict(client: Redis):
96+
s = await cls.get(redis_client=client)
97+
keys = set(file_keys)
98+
s.active_uploads = [u for u in s.active_uploads if u.upload_key not in keys]
99+
s.total_space_used = max(0, s.total_space_used - freed_bytes)
100+
await cls.set(s, redis_client=client)
101+
102+
if redis_client:
145103
await _evict(redis_client)
146104
else:
147105
async with redis.from_url(
@@ -150,10 +108,15 @@ async def _evict(client: Redis) -> None:
150108
await _evict(client)
151109

152110
@classmethod
153-
async def patch(cls, **kwargs: Any) -> AppState:
154-
"""Partially update the state, persist, and notify."""
155-
current = await cls.get()
156-
updated = current.model_copy(update=kwargs)
157-
updated = cls.model_validate(updated.model_dump())
158-
await cls.set(updated)
159-
return updated
111+
async def ensure_created(cls) -> None:
112+
from sqlmodel import select
113+
from app.db import AsyncSessionLocal
114+
from app.models.config import Config
115+
116+
async with AsyncSessionLocal() as session:
117+
config = (await session.exec(select(Config))).first()
118+
limit = config.total_storage_limit if config else None
119+
120+
s = await cls.get()
121+
s.total_available_space = limit
122+
await cls.set(s)

0 commit comments

Comments
 (0)