Skip to content

Commit e8a15ad

Browse files
committed
Implement batch processing for member upserts
Replaced per-member database upserts with batched processing to improve efficiency. Added periodic flushing of batched member data and a system to handle batch limits. Ensured proper cleanup of tasks during shutdown to avoid dangling operations.
1 parent 8728c92 commit e8a15ad

File tree

1 file changed

+111
-28
lines changed

1 file changed

+111
-28
lines changed

discord/state.py

Lines changed: 111 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,22 @@ def __init__(
276276
parsers[attr[6:].upper()] = func
277277

278278
self.clear()
279+
self._member_batch: List[Dict[str, Any]] = []
280+
self._member_batch_lock = asyncio.Lock()
281+
self._batch_size = 250
282+
self._batch_flush_interval = 3600 # 1 hour between auto flushes
283+
self._flush_task = self.loop.create_task(self._periodic_flush())
284+
285+
async def _periodic_flush(self) -> None:
286+
"""Periodically flush member batches to the database."""
287+
try:
288+
while True:
289+
await asyncio.sleep(self._batch_flush_interval)
290+
await self._flush_member_batch()
291+
except asyncio.CancelledError:
292+
# Flush remaining members if needed before exit.
293+
await self._flush_member_batch()
294+
279295

280296
# For some reason Discord still sends emoji/sticker data in payloads
281297
# This makes it hard to actually swap out the appropriate store methods
@@ -291,6 +307,14 @@ async def close(self) -> None:
291307
except Exception:
292308
# if an error happens during disconnects, disregard it.
293309
pass
310+
if hasattr(self, "_flush_task"):
311+
self._flush_task.cancel()
312+
try:
313+
await self._flush_task
314+
except asyncio.CancelledError:
315+
pass
316+
except Exception:
317+
pass
294318

295319
if self._translator:
296320
await self._translator.unload()
@@ -309,6 +333,80 @@ def cluster_id(self) -> int | None:
309333

310334
return int(found)
311335

336+
async def _flush_member_batch(self) -> None:
337+
"""
338+
Flush the queued member upserts in one batched query.
339+
"""
340+
async with self._member_batch_lock:
341+
if not self._member_batch:
342+
return
343+
# Extract the batch and reset the container.
344+
batch = self._member_batch
345+
self._member_batch = []
346+
347+
# Prepare arrays for our batched upsert.
348+
user_ids: List[int] = []
349+
members_data: List[Any] = [] # Adjust type if needed (e.g. to JSON/dict)
350+
cluster_ids: List[int] = []
351+
guild_ids: List[int] = []
352+
353+
for item in batch:
354+
member = item["member"]
355+
guild_id = item["guild_id"]
356+
user = member.get("user")
357+
if not user or "id" not in user:
358+
logging.warning("Invalid member payload during batch flush: %s", member)
359+
continue
360+
try:
361+
uid = int(user["id"])
362+
except (ValueError, TypeError):
363+
logging.warning("Invalid user id in payload: %s", user)
364+
continue
365+
366+
user_ids.append(uid)
367+
members_data.append(member)
368+
cluster_ids.append(self.cluster_id)
369+
guild_ids.append(guild_id)
370+
371+
# If after filtering there is nothing to flush, return.
372+
if not user_ids:
373+
return
374+
375+
query = """
376+
WITH upsert_user AS (
377+
INSERT INTO discord_users (user_id, data, cluster_id)
378+
SELECT uid, data, cluster_id FROM (
379+
SELECT unnest($1::bigint[]) AS uid,
380+
unnest($2::jsonb[]) AS data,
381+
unnest($3::int[]) AS cluster_id
382+
) AS vals
383+
ON CONFLICT (user_id) DO UPDATE
384+
SET data = EXCLUDED.data,
385+
cluster_id = EXCLUDED.cluster_id
386+
)
387+
INSERT INTO discord_members (guild_id, user_id, data, cluster_id)
388+
SELECT guild_id, uid, data, cluster_id FROM (
389+
SELECT unnest($4::int[]) AS guild_id,
390+
unnest($1::bigint[]) AS uid,
391+
unnest($2::jsonb[]) AS data,
392+
unnest($3::int[]) AS cluster_id
393+
) AS vals
394+
ON CONFLICT (guild_id, user_id) DO UPDATE
395+
SET data = EXCLUDED.data,
396+
cluster_id = EXCLUDED.cluster_id;
397+
"""
398+
try:
399+
await self.database.execute(
400+
query,
401+
user_ids,
402+
members_data,
403+
cluster_ids,
404+
guild_ids,
405+
)
406+
except Exception as e:
407+
logging.error("Failed to store %d member(s): %s", len(user_ids), e)
408+
409+
312410
async def user_to_db(self, user: Dict[str, Any]) -> None:
313411
"""
314412
Upsert a user into discord_users using asyncpg's implicit statement caching.
@@ -343,42 +441,27 @@ async def user_to_db(self, user: Dict[str, Any]) -> None:
343441

344442
async def member_to_db(self, guild_id: int, member: Dict[str, Any]) -> None:
    """
    Queue a member upsert to later batch-process both user and member data.

    Validates the payload, requires a configured database and a known
    cluster_id, appends the member to the in-memory batch, and triggers an
    immediate flush once the batch reaches ``self._batch_size``.
    """
    if not getattr(self, "database", None):
        return

    # Validate member payload before queueing so the flusher only sees
    # well-formed entries where possible.
    user = member.get("user")
    if not user or "id" not in user:
        logging.warning("Invalid member payload: %s", member)
        return

    if self.cluster_id is None:
        return

    async with self._member_batch_lock:
        self._member_batch.append({"guild_id": guild_id, "member": member})
        # Only record the decision here; do NOT flush while holding the lock.
        flush_needed = len(self._member_batch) >= self._batch_size

    # BUGFIX: _flush_member_batch acquires the same non-reentrant
    # asyncio.Lock, so awaiting it inside the `async with` above deadlocks
    # on the first size-triggered flush. Flush after releasing the lock.
    if flush_needed:
        await self._flush_member_batch()
382465

383466
async def remove_user_from_db(self, user_id: int) -> None:
384467
"""

0 commit comments

Comments
 (0)