Skip to content

Commit 573e3b2

Browse files
authored
Merge pull request #7031 from hotosm/feat/update-users-concurrently
Update user stats in parallel
2 parents 3212b6e + 4cd7704 commit 573e3b2

File tree

1 file changed

+171
-22
lines changed

1 file changed

+171
-22
lines changed
Lines changed: 171 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,161 @@
11
import asyncio
22
import sys
33
import os
4+
import logging
5+
import argparse
6+
from types import SimpleNamespace
7+
from typing import List
8+
49
from backend.db import db_connection
5-
from backend.models.postgis.mapping_badge import MappingBadge
6-
from backend.models.postgis.mapping_level import MappingLevel
710
from backend.services.users.user_service import UserService
8-
from backend.models.postgis.user import User, UserNextLevel
9-
import logging
11+
from backend.models.postgis.user import User
12+
13+
import httpx
1014

1115
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
1216

1317
logging.basicConfig(level=logging.INFO)
1418
logger = logging.getLogger(__name__)
1519

20+
# Defaults come from env or fall back to sensible values
21+
DEFAULT_MAX_CONCURRENT_WORKERS = int(os.getenv("MAX_CONCURRENT_WORKERS", "50"))
22+
DEFAULT_TASK_BATCH_SIZE = int(os.getenv("TASK_BATCH_SIZE", "5000"))
23+
PROGRESS_PRINT_EVERY = int(os.getenv("PROGRESS_PRINT_EVERY", "1000"))
24+
25+
RETRIES = 2
26+
RETRY_DELAY = 1.0
1627

17-
async def main():
28+
29+
async def process_user(
    user_record,
    failed_users,
    failed_lock,
    users_updated_counter,
    counter_lock,
    semaphore,
):
    """
    Update stats/mapper level for one user on its own pooled DB connection.

    Calls ``UserService.check_and_update_mapper_level(user_id, conn)``.
    Transient httpx errors are retried up to ``RETRIES`` attempts with a
    fixed ``RETRY_DELAY`` pause; on permanent failure the user id is added
    to ``failed_users``.  The shared progress counter is bumped in a
    ``finally`` block so every scheduled user is counted exactly once,
    success or failure.

    :param user_record: row object exposing an ``id`` attribute
    :param failed_users: shared list collecting ids that could not be updated
    :param failed_lock: asyncio.Lock guarding ``failed_users``
    :param users_updated_counter: one-element list used as a mutable counter
    :param counter_lock: asyncio.Lock guarding ``users_updated_counter``
    :param semaphore: asyncio.Semaphore bounding concurrent workers
    """
    # `async with` guarantees the semaphore slot is released even if the
    # body raises — replaces the manual acquire()/release() pairing.
    async with semaphore:
        try:
            async with db_connection.database.connection() as conn:
                attempt = 0
                while True:
                    attempt += 1
                    try:
                        await UserService.check_and_update_mapper_level(
                            user_record.id, conn
                        )
                        break
                    except httpx.HTTPError as exc:
                        # httpx.ReadTimeout and httpx.TransportError are
                        # subclasses of httpx.HTTPError, so this single
                        # clause covers the same transient-network set.
                        if attempt >= RETRIES:
                            async with failed_lock:
                                failed_users.append(user_record.id)
                            logger.exception(
                                "Failed to update stats/mapper level for user %s after %d attempts",
                                user_record.id,
                                attempt,
                            )
                            break
                        logger.warning(
                            "Transient error for user %s (attempt %d/%d): %s — retrying in %.1fs",
                            user_record.id,
                            attempt,
                            RETRIES,
                            exc.__class__.__name__,
                            RETRY_DELAY,
                        )
                        await asyncio.sleep(RETRY_DELAY)
                    except Exception:
                        # Non-network failure: record and move on — one bad
                        # user must not abort the whole run.
                        async with failed_lock:
                            failed_users.append(user_record.id)
                        logger.exception(
                            "Failed to update stats/mapper level for user %s",
                            user_record.id,
                        )
                        break
        finally:
            # Count every processed user (including failures) and emit a
            # periodic progress line.
            async with counter_lock:
                users_updated_counter[0] += 1
                updated = users_updated_counter[0]
                if updated % PROGRESS_PRINT_EVERY == 0:
                    logger.info(f"{updated} users updated")
93+
94+
95+
async def _fetch_users_only_missing(conn) -> List[SimpleNamespace]:
96+
"""
97+
Return lightweight objects (id, username) for users missing user_stats entries.
98+
"""
99+
users = await conn.fetch_all(
100+
query="""
101+
SELECT u.id, u.username
102+
FROM users u
103+
WHERE NOT EXISTS (
104+
SELECT 1 FROM user_stats s WHERE s.user_id = u.id
105+
)
106+
ORDER BY u.id
107+
"""
108+
)
109+
return users
110+
111+
112+
async def main(only_missing: bool, workers: int, batch_size: int):
18113
try:
19114
logger.info("Connecting to database...")
20115
await db_connection.connect()
21-
db = db_connection.database
22116

23117
logger.info("Started updating mapper levels...")
24-
users = await User.get_all_users_not_paginated(db)
118+
logger.info(
119+
"Using %d concurrent workers, task batch size %d", workers, batch_size
120+
)
121+
122+
async with db_connection.database.connection() as conn:
123+
if only_missing:
124+
users = await _fetch_users_only_missing(conn)
125+
else:
126+
users = await User.get_all_users_not_paginated(conn)
127+
25128
total_users = len(users)
129+
logger.info("Fetched %d users to process", total_users)
130+
26131
failed_users = []
27-
users_updated = 0
28-
29-
for user in users:
30-
try:
31-
await UserService.check_and_update_mapper_level(user.id, db)
32-
except Exception:
33-
failed_users.append(user.id)
34-
logger.exception(
35-
"Failed to update stats/mapper level for user %s",
36-
user.id,
132+
failed_lock = asyncio.Lock()
133+
users_updated_counter = [0]
134+
counter_lock = asyncio.Lock()
135+
semaphore = asyncio.Semaphore(workers)
136+
137+
for start in range(0, total_users, batch_size):
138+
end = min(start + batch_size, total_users)
139+
batch = users[start:end]
140+
logger.info("Scheduling batch %d..%d (size=%d)", start + 1, end, len(batch))
141+
142+
tasks = [
143+
asyncio.create_task(
144+
process_user(
145+
user_record,
146+
failed_users,
147+
failed_lock,
148+
users_updated_counter,
149+
counter_lock,
150+
semaphore,
151+
)
37152
)
38-
continue
153+
for user_record in batch
154+
]
39155

40-
users_updated += 1
41-
if users_updated % 1000 == 0:
42-
logger.info(f"{users_updated} users updated of {total_users}")
156+
await asyncio.gather(*tasks)
43157

158+
users_updated = users_updated_counter[0]
44159
logger.info(f"Finished. Updated {users_updated} user mapper levels.")
45160
logger.info(f"Failed stats update for these users: {failed_users}.")
46161

@@ -56,4 +171,38 @@ async def main():
56171

57172

58173
if __name__ == "__main__":
    # Command-line front end: tuning flags default to the env-derived
    # module constants.
    cli = argparse.ArgumentParser(description="Refresh mapper levels / user stats")
    cli.add_argument(
        "--only-missing",
        action="store_true",
        help="Process only users that do not have an entry in user_stats",
    )
    cli.add_argument(
        "--workers",
        "-w",
        type=int,
        default=DEFAULT_MAX_CONCURRENT_WORKERS,
        help=f"Number of concurrent workers (default {DEFAULT_MAX_CONCURRENT_WORKERS})",
    )
    cli.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=DEFAULT_TASK_BATCH_SIZE,
        help=f"Number of users scheduled per batch (default {DEFAULT_TASK_BATCH_SIZE})",
    )
    options = cli.parse_args()

    # Reject non-positive tuning values before touching the database.
    if options.workers <= 0:
        cli.error("--workers must be a positive integer")
    if options.batch_size <= 0:
        cli.error("--batch-size must be a positive integer")

    asyncio.run(
        main(
            only_missing=options.only_missing,
            workers=options.workers,
            batch_size=options.batch_size,
        )
    )

0 commit comments

Comments
 (0)