|
1 | 1 | import re |
2 | | -from typing import Union |
| 2 | +from typing import Union, Dict, Tuple |
3 | 3 |
|
4 | 4 | from sys import stderr |
5 | 5 |
|
6 | 6 | from openmaptiles.tileset import Tileset, Layer |
7 | 7 |
|
8 | 8 |
|
9 | | -def collect_sql(tileset_filename, parallel=False, nodata=False): |
10 | | - """If parallel is True, returns a sql value that must be executed first, |
11 | | - and a lists of sql values that can be ran in parallel. |
def collect_sql(tileset_filename, parallel=False, nodata=False
                ) -> Union[str, Tuple[str, Dict[str, str], str]]:
    """Generate the SQL for every layer of a tileset, honoring layer requirements.

    Layers that depend on each other (via ``layer.requires``) are concatenated
    into a single SQL chunk, with the required layers placed before the layer
    that needs them. Independent layers stay in separate chunks.

    If parallel is True, returns a ``(run_first, chunks, run_last)`` tuple:
    SQL that must be executed first, a dict of names -> SQL code that can be
    run in parallel, and SQL that must be executed last.
    If parallel is False, returns a single SQL string with everything in order.
    nodata is passed through to layer_to_sql; nodata=True replaces all
    "/* DELAY_MATERIALIZED_VIEW_CREATION */" with the "WITH NO DATA" SQL.

    Raises ValueError if two merged chunks would end up with the same name,
    or if some layers can never have their requirements satisfied.
    """
    tileset = Tileset(tileset_filename)

    language_sql = get_slice_language_tags(tileset.languages)
    run_first = "-- This SQL code should be executed first\n\n" + language_sql
    # at this point we don't have any SQL to run at the end
    run_last = "-- This SQL code should be executed last\n"

    # layer ID -> name of the chunk (key in sql_by_group) holding its SQL.
    # The chunk name is either the layer ID itself, or several IDs joined
    # with "__" once layers have been merged together.
    group_of = {}
    # chunk name -> SQL content
    sql_by_group = {}
    pending = tileset.layers_by_id.copy()

    # Keep sweeping over the pending layers until a full pass resolves nothing;
    # this guards against an infinite loop on unsatisfiable requirements.
    progressed = True
    while progressed:
        progressed = False
        for layer_id, layer in list(pending.items()):
            if any(req not in group_of for req in layer.requires):
                # Some requirement is still pending - retry on a later pass.
                continue

            group_of[layer_id] = layer_id
            sql_by_group[layer_id] = layer_to_sql(layer, nodata)
            del pending[layer_id]
            progressed = True

            if not layer.requires:
                continue

            # Fold every requirement's chunk, and finally this layer's own
            # chunk, into the chunk of the first requirement. E.g. for layers
            # A, B, and C where C requires A & B: concatenate A and B, then
            # append C. Chunks that are already the same are skipped so the
            # same code is never merged twice.
            anchor, *others = [*layer.requires, layer_id]
            for other in others:
                left = group_of[anchor]
                right = group_of[other]
                if left == right:
                    continue
                combined = left + "__" + right
                if combined in sql_by_group:
                    raise ValueError(f"Naming collision - {combined} exists")
                # NOTE: merging moves the combined chunk to the end of the dict
                sql_by_group[combined] = (
                    sql_by_group.pop(left) + "\n" + sql_by_group.pop(right))
                # Re-point every layer of either chunk at the combined chunk
                for member, group in group_of.items():
                    if group in (left, right):
                        group_of[member] = combined

    if pending:
        raise ValueError("Circular dependency found in layer requirements: " +
                         ', '.join(pending.keys()))

    if parallel:
        return run_first, sql_by_group, run_last

    sql = '\n'.join(sql_by_group.values())
    return f"{run_first}\n{sql}\n{run_last}"
16 | 73 |
|
17 | | - run_first = get_slice_language_tags(tileset.languages) |
18 | | - run_last = '' # at this point we don't have any SQL to run at the end |
19 | 74 |
|
20 | | - parallel_sql = [] |
21 | | - for layer in tileset.layers: |
22 | | - schemas = '\n\n'.join((to_sql(v, layer, nodata) for v in layer.schemas)) |
23 | | - parallel_sql.append(f"""\ |
def layer_to_sql(layer: Layer, nodata: bool):
    """Return the SQL of one layer, wrapped in progress NOTICE statements.

    Each entry of ``layer.schemas`` is passed through to_sql (together with
    the layer and the nodata flag), and the results are separated by blank
    lines. The output starts with a 'Processing layer <id>' notice, ends with
    a 'Finished layer <id>' notice, and is terminated by exactly one newline.
    """
    statements = [to_sql(schema, layer, nodata) for schema in layer.schemas]
    body = '\n\n'.join(statements)
    header = f"DO $$ BEGIN RAISE NOTICE 'Processing layer {layer.id}'; END$$;"
    footer = f"DO $$ BEGIN RAISE NOTICE 'Finished layer {layer.id}'; END$$;"
    # Header, body and footer are separated by one blank line each.
    block = '\n\n'.join((header, body, footer))
    return block.strip() + "\n"
35 | 85 |
|
36 | 86 |
|
37 | 87 | def get_slice_language_tags(languages): |
@@ -143,7 +193,7 @@ def sql_value(value): |
143 | 193 | return "E'" + value.replace('\\', '\\\\').replace("'", "\\'") + "'" |
144 | 194 |
|
145 | 195 |
|
146 | | -def to_sql(sql, layer, nodata): |
| 196 | +def to_sql(sql: str, layer: Layer, nodata: bool): |
147 | 197 | """Clean up SQL, and perform any needed code injections""" |
148 | 198 | sql = sql.strip() |
149 | 199 |
|
|
0 commit comments