@@ -21,6 +21,11 @@
 except ImportError:
     from UserList import UserList
 
+try:  # noqa: SIM105
+    range = xrange  # noqa: A001
+except NameError:
+    pass
+
 import psycopg2
 from psycopg2 import sql
 
@@ -36,6 +41,7 @@
 _logger = logging.getLogger(__name__)
 
 ON_DELETE_ACTIONS = frozenset(("SET NULL", "CASCADE", "RESTRICT", "NO ACTION", "SET DEFAULT"))
+MAX_BUCKETS = int(os.getenv("MAX_BUCKETS", "150000"))
 
 
 class PGRegexp(str):
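Aside (editor's note, not part of the diff): the shim above rebinds `range` to Python 2's lazy `xrange`, so the bucket arithmetic further down behaves the same on both interpreters, and `MAX_BUCKETS` caps how many bucket queries the evenly-spaced strategy may emit, tunable per run through the environment. A minimal sketch of that env-override pattern; the 500000 is an illustrative value, not from the source:

    import os

    # Falls back to 150000 unless e.g. MAX_BUCKETS=500000 is exported in the environment.
    MAX_BUCKETS = int(os.getenv("MAX_BUCKETS", "150000"))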
@@ -196,27 +202,57 @@ def explode_query_range(cr, query, table, alias=None, bucket_size=10000, prefix=
 
     alias = alias or table
 
-    cr.execute("SELECT min(id), max(id) FROM {}".format(table))
+    if "{parallel_filter}" not in query:
+        sep_kw = " AND " if re.search(r"\sWHERE\s", query, re.M | re.I) else " WHERE "
+        query += sep_kw + "{parallel_filter}"
+
+    cr.execute(format_query(cr, "SELECT min(id), max(id) FROM {}", table))
     min_id, max_id = cr.fetchone()
     if min_id is None:
         return []  # empty table
+    count = (max_id + 1 - min_id) // bucket_size
+    if count > MAX_BUCKETS:
+        _logger.getChild("explode_query_range").warning(
+            "High number of queries generated (%s); switching to a precise bucketing strategy", count
+        )
+        cr.execute(
+            format_query(
+                cr,
+                """
+                WITH t AS (
+                    SELECT id,
+                           mod(row_number() OVER(ORDER BY id) - 1, %s) AS g
+                      FROM {table}
+                  ORDER BY id
+                ) SELECT array_agg(id ORDER BY id) FILTER (WHERE g=0),
+                         min(id),
+                         max(id)
+                    FROM t
+                """,
+                table=table,
+            ),
+            [bucket_size],
+        )
+        ids, min_id, max_id = cr.fetchone()
+    else:
+        ids = list(range(min_id, max_id + 1, bucket_size))
 
-    if "{parallel_filter}" not in query:
-        sep_kw = " AND " if re.search(r"\sWHERE\s", query, re.M | re.I) else " WHERE "
-        query += sep_kw + "{parallel_filter}"
+    assert min_id == ids[0] and max_id + 1 != ids[-1]  # sanity checks
+    ids.append(max_id + 1)  # ensure the last bucket covers the whole range
+    # `ids` now holds the interval boundaries of all buckets
 
-    if ((max_id - min_id + 1) * 0.9) <= bucket_size:
-        # If there is less than `bucket_size` records (with a 10% tolerance), no need to explode the query.
-        # Force usage of `prefix` in the query to validate it correctness.
-        # If we don't the query may only be valid if there is no split. It avoid scripts to pass the CI but fail in production.
+    if (max_id - min_id + 1) <= 1.1 * bucket_size or (len(ids) == 3 and ids[2] - ids[1] <= 0.1 * bucket_size):
+        # If we return a single query, `parallel_execute` skips spawning new threads. Thus we return only one query
+        # when there are just two buckets and the second would hold at most 10% of `bucket_size` records.
+        # Still, since the query may only be valid when there is no split, we force the usage of `prefix` in the
+        # query to validate its correctness and avoid scripts that pass the CI but fail in production.
         parallel_filter = "{alias}.id IS NOT NULL".format(alias=alias)
         return [query.format(parallel_filter=parallel_filter)]
 
     parallel_filter = "{alias}.id BETWEEN %(lower-bound)s AND %(upper-bound)s".format(alias=alias)
     query = query.replace("%", "%%").format(parallel_filter=parallel_filter)
     return [
-        cr.mogrify(query, {"lower-bound": index, "upper-bound": index + bucket_size - 1}).decode()
-        for index in range(min_id, max_id, bucket_size)
+        cr.mogrify(query, {"lower-bound": ids[i], "upper-bound": ids[i + 1] - 1}).decode() for i in range(len(ids) - 1)
     ]
 
 
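Aside (editor's sketch, not part of the diff): the rewritten return path turns each pair of consecutive boundaries in `ids` into one inclusive BETWEEN window. A self-contained illustration of both bucketing branches; every literal and the `fake` list below are made up for the example:

    # Naive branch: evenly spaced boundaries across the id span.
    min_id, max_id, bucket_size = 1, 25, 10
    ids = list(range(min_id, max_id + 1, bucket_size))  # [1, 11, 21]
    ids.append(max_id + 1)                              # [1, 11, 21, 26]
    print([(ids[i], ids[i + 1] - 1) for i in range(len(ids) - 1)])
    # [(1, 10), (11, 20), (21, 25)] -> one BETWEEN clause per pair

    # Precise branch, mirroring the window-function query: every
    # `bucket_size`-th *existing* id becomes a boundary, so buckets stay
    # equally full even when ids are sparse.
    fake = [1, 2, 3, 500, 501, 502, 9000]  # existing ids, with large gaps
    ids = fake[::3]                        # rows where mod(row_number() - 1, 3) = 0 -> [1, 500, 9000]
    ids.append(fake[-1] + 1)               # [1, 500, 9000, 9001]
    print([(ids[i], ids[i + 1] - 1) for i in range(len(ids) - 1)])
    # [(1, 499), (500, 8999), (9000, 9000)] -> three existing ids per bucket, one in the last

With the defaults (bucket_size=10000, MAX_BUCKETS=150000), the precise strategy only kicks in once the id span exceeds roughly 1.5 billion, since (max_id + 1 - min_id) // 10000 must exceed 150000.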