Skip to content

Commit 6cf66b0

Browse files
committed
Optimized memory.
1 parent f12d14c commit 6cf66b0

File tree

1 file changed

+112
-104
lines changed

1 file changed

+112
-104
lines changed

functions-python/tasks_executor/src/tasks/pmtiles_builder/build_pmtiles.py

Lines changed: 112 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import json
2222
import logging
2323
import os
24-
import pickle
2524
import shutil
2625
import subprocess
2726
from logging import DEBUG
@@ -67,6 +66,15 @@ def __init__(
6766
self.bucket_name = os.getenv("DATASETS_BUCKET_NAME")
6867
self.logger = get_logger(PmtilesBuilder.__name__, dataset_stable_id)
6968

69+
self.stop_times_index = {}
70+
self.stop_times_by_trip = None
71+
72+
self.stop_times_file = f"{local_dir}/stop_times.txt"
73+
self.shapes_file = f"{local_dir}/shapes.txt"
74+
self.trips_file = f"{local_dir}/trips.txt"
75+
self.routes_file = f"{local_dir}/routes.txt"
76+
self.stops_file = f"{local_dir}/stops.txt"
77+
7078
@staticmethod
7179
def _get_parameters(payload):
7280
"""
@@ -103,8 +111,6 @@ def build_pmtiles(self) -> dict:
103111

104112
self._download_files_from_gcs(unzipped_files_path)
105113

106-
self._create_shapes_index()
107-
108114
self._create_routes_geojson()
109115

110116
self._run_tippecanoe("routes-output.geojson", "routes.pmtiles")
@@ -215,15 +221,24 @@ def _upload_files_to_gcs(self, file_to_upload):
215221
except Exception as e:
216222
raise Exception(f"Failed to upload files to GCS: {e}") from e
217223

218-
def _create_shapes_index(self):
224+
def _create_shapes_index(self) -> dict:
225+
"""
226+
Create an index for shapes.txt file to quickly access shape points by shape_id.
227+
We create the index to save memory. With the index, we keep a list of positions in the file for each shape.
228+
If instead we read the whole file into memory, we would need 2 floats for the longitude and latitude plus an
229+
int for the sequence number for each point.
230+
The largest number of shape points we currently have in a dataset is 37 million.
231+
This means about 900 MB if we have the index, and 1.6 GB if we read the coordinates into memory.
232+
Returns:
233+
A dictionary keyed by shape_id, whose values are lists of positions in the shapes.txt file.
234+
"""
219235
self.logger.info("Creating shapes index")
236+
shapes_index = {}
220237
try:
221-
index = {}
222-
shapes = f"{local_dir}/shapes.txt"
223-
outfile = f"{local_dir}/shapes_index.pkl"
224-
with open(shapes, "r", encoding="utf-8", newline="") as f:
238+
with open(self.shapes_file, "r", encoding="utf-8", newline="") as f:
225239
header = f.readline()
226240
columns = next(csv.reader([header]))
241+
shapes_index["columns"] = columns
227242
count = 0
228243
while True:
229244
pos = f.tell()
@@ -232,16 +247,15 @@ def _create_shapes_index(self):
232247
break
233248
row = dict(zip(columns, next(csv.reader([line]))))
234249
sid = row["shape_id"]
235-
index.setdefault(sid, []).append(pos)
250+
shapes_index.setdefault(sid, []).append(pos)
236251
count += 1
237252
if count % 1000000 == 0:
238253
self.logger.debug("Indexed %d lines so far...", count)
239254
self.logger.debug("Total indexed lines: %d", count)
240-
self.logger.debug("Total unique shape_ids: %d", len(index))
241-
with open(outfile, "wb") as idxf:
242-
pickle.dump(index, idxf)
255+
self.logger.debug("Total unique shape_ids: %d", len(shapes_index))
243256
except Exception as e:
244257
raise Exception(f"Failed to create shapes index: {e}") from e
258+
return shapes_index
245259

246260
def _read_csv(self, filename):
247261
try:
@@ -255,8 +269,7 @@ def _get_shape_points(self, shape_id, index):
255269
self.logger.debug("Getting shape points for shape_id %s", shape_id)
256270
try:
257271
points = []
258-
shapes_file = f"{local_dir}/shapes.txt"
259-
with open(shapes_file, "r", encoding="utf-8", newline="") as f:
272+
with open(self.shapes_file, "r", encoding="utf-8", newline="") as f:
260273
for pos in index.get(shape_id, []):
261274
f.seek(pos)
262275
line = f.readline()
@@ -277,114 +290,109 @@ def _get_shape_points(self, shape_id, index):
277290
raise Exception(f"Failed to get shape points for {shape_id}: {e}") from e
278291

279292
def _create_routes_geojson(self):
280-
self.logger.info("Creating routes geojson")
281293
try:
282-
self.logger.debug("Loading shapes_index.pkl...")
283-
shapes_index_file = f"{local_dir}/shapes_index.pkl"
284-
shapes_file = f"{local_dir}/shapes.txt"
285-
trips_file = f"{local_dir}/trips.txt"
286-
routes_file = f"{local_dir}/routes.txt"
287-
stops_file = f"{local_dir}/stops.txt"
288-
stop_times_file = f"{local_dir}/stop_times.txt"
289-
with open(shapes_index_file, "rb") as idxf:
290-
shapes_index = pickle.load(idxf)
291-
self.logger.debug("Loaded index for %d shape_ids.", len(shapes_index))
292-
293-
with open(shapes_file, "r", encoding="utf-8", newline="") as f:
294-
header = f.readline()
295-
shapes_columns = next(csv.reader([header]))
296-
shapes_index["columns"] = shapes_columns
297-
298-
routes = {r["route_id"]: r for r in self._read_csv(routes_file)}
299-
self.logger.debug("Loaded %d routes.", len(routes))
294+
shapes_index = self._create_shapes_index()
295+
self.logger.info("Creating routes geojson (optimized for memory)")
300296

301-
trips = list(self._read_csv(trips_file))
302-
self.logger.debug("Loaded %d trips.", len(trips))
303-
304-
stops = {
297+
# Load stops into memory (usually not huge)
298+
# Used only if there are no shapes for a route
299+
coordinates_indexed_by_stop = {
305300
s["stop_id"]: (float(s["stop_lon"]), float(s["stop_lat"]))
306-
for s in self._read_csv(stops_file)
301+
for s in self._read_csv(self.stops_file)
307302
}
308-
self.logger.debug("Loaded %d stops.", len(stops))
309303

310-
stop_times_by_trip = {}
311-
self.logger.debug(
312-
"Grouping stop_times by trip_id for dataset %s",
313-
self.dataset_stable_id,
314-
)
315-
with open(stop_times_file, newline="", encoding="utf-8") as f:
316-
reader = csv.DictReader(f)
317-
for row in reader:
318-
stop_times_by_trip.setdefault(row["trip_id"], []).append(row)
319-
self.logger.debug(
320-
"Grouped stop_times for %d trips.", len(stop_times_by_trip)
321-
)
304+
# We want the shape of a route. To do that we find one trip for each route. We will assume that all
305+
# trips for a route have the same shape_id. If not, I am not sure how to represent this in the route map
306+
# when we select a route.
307+
shape_map_indexed_by_route = {}
308+
trip_map_indexed_by_route = {}
309+
with open(self.trips_file, newline="", encoding="utf-8") as f:
310+
for row in csv.DictReader(f):
311+
trip_id = row["trip_id"]
312+
route_id = row["route_id"]
313+
shape_id = row.get("shape_id", "")
314+
if shape_id and route_id not in shape_map_indexed_by_route:
315+
shape_map_indexed_by_route[route_id] = shape_id
316+
if trip_id and route_id not in trip_map_indexed_by_route:
317+
trip_map_indexed_by_route[route_id] = trip_id
322318

323319
features = []
324320
missing_coordinates_routes = set()
325-
for i, (route_id, route) in enumerate(routes.items(), 1):
326-
if i % 100 == 0 or i == 1:
327-
self.logger.debug(
328-
"Processing route %d/%d (route_id: %s)",
329-
i,
330-
len(routes),
331-
route_id,
332-
)
333-
trip = next((t for t in trips if t["route_id"] == route_id), None)
334-
if not trip:
335-
self.logger.iunfo(
336-
" No trip found for route_id %s, skipping.", route_id
337-
)
338-
continue
339-
coordinates = []
340-
if "shape_id" in trip and trip["shape_id"]:
341-
self.logger.debug(
342-
" Using shape_id %s for route_id %s",
343-
trip["shape_id"],
344-
route_id,
345-
)
346-
coordinates = self._get_shape_points(trip["shape_id"], shapes_index)
347-
if isinstance(coordinates, dict) and "error" in coordinates:
348-
raise Exception(
349-
f"Error getting shape points for shape_id {trip['shape_id']}: {coordinates['error']}"
350-
)
351-
if not coordinates:
352-
trip_stop_times = stop_times_by_trip.get(trip["trip_id"], [])
353-
trip_stop_times.sort(key=lambda x: int(x["stop_sequence"]))
354-
coordinates = [
355-
stops[st["stop_id"]]
356-
for st in trip_stop_times
357-
if st["stop_id"] in stops
358-
]
359-
self.logger.debug(
360-
" Used %d stop coordinates for route_id %s",
361-
len(coordinates),
362-
route_id,
363-
)
364-
if not coordinates:
365-
missing_coordinates_routes.add(route_id)
366-
continue
367-
features.append(
368-
{
369-
"type": "Feature",
370-
"properties": {k: route[k] for k in route},
371-
"geometry": {"type": "LineString", "coordinates": coordinates},
372-
}
373-
)
321+
routes_geojson = f"{local_dir}/routes-output.geojson"
322+
with open(routes_geojson, "w", encoding="utf-8") as geojson_file:
323+
geojson_file.write('{"type": "FeatureCollection", "features": [\n')
324+
first = True
325+
with open(
326+
self.routes_file, newline="", encoding="utf-8"
327+
) as routes_file:
328+
for i, route in enumerate(csv.DictReader(routes_file), 1):
329+
route_id = route["route_id"]
330+
331+
shape_id = shape_map_indexed_by_route.get(route_id, "")
332+
333+
coordinates = []
334+
if shape_id:
335+
coordinates = self._get_shape_points(shape_id, shapes_index)
336+
if not coordinates:
337+
# We don't have the coordinates for the shape, fallback on stop_times and stops
338+
trip_id = trip_map_indexed_by_route.get(route_id, "")
339+
340+
if trip_id:
341+
trip_stop_times = self._get_trip_stop_times(trip_id)
342+
# We assume stop_times is already sorted by stop_sequence in the file.
343+
# According to the SPECS:
344+
# The values must increase along the trip but do not need to be consecutive.
345+
coordinates = [
346+
coordinates_indexed_by_stop[stop_id]
347+
for stop_id in trip_stop_times
348+
if stop_id in coordinates_indexed_by_stop
349+
]
350+
if not coordinates:
351+
missing_coordinates_routes.add(route_id)
352+
continue
353+
feature = {
354+
"type": "Feature",
355+
"properties": {k: route[k] for k in route},
356+
"geometry": {
357+
"type": "LineString",
358+
"coordinates": coordinates,
359+
},
360+
}
361+
362+
if not first:
363+
geojson_file.write(",\n")
364+
geojson_file.write(json.dumps(feature))
365+
first = False
366+
367+
if i % 100 == 0 or i == 1:
368+
self.logger.debug(
369+
"Processed route %d (route_id: %s)", i, route_id
370+
)
371+
372+
geojson_file.write("\n]}")
374373

375374
if missing_coordinates_routes:
376375
self.logger.info(
377376
"Routes without coordinates: %s", list(missing_coordinates_routes)
378377
)
379378
self.logger.debug(
380-
"Writing %d features to routes-output.geojson...", len(features)
379+
"Wrote %d features to routes-output.geojson", len(features)
381380
)
382-
routes_geojson = f"{local_dir}/routes-output.geojson"
383-
with open(routes_geojson, "w", encoding="utf-8") as f:
384-
json.dump({"type": "FeatureCollection", "features": features}, f)
385381
except Exception as e:
386382
raise Exception(f"Failed to create routes GeoJSON: {e}") from e
387383

384+
def _get_trip_stop_times(self, trip_id):
385+
# Lazy instantiation of the dictionary, because we may not need it at all if there is a shape.
386+
if self.stop_times_by_trip is None:
387+
self.stop_times_by_trip = {}
388+
with open(self.stop_times_file, newline="", encoding="utf-8") as f:
389+
for row in csv.DictReader(f):
390+
self.stop_times_by_trip.setdefault(row["trip_id"], []).append(
391+
row["stop_id"]
392+
)
393+
394+
return self.stop_times_by_trip.get(trip_id, [])
395+
388396
def _run_tippecanoe(self, input_file, output_file):
389397
self.logger.info("Running tippecanoe for input file %s", input_file)
390398
try:

0 commit comments

Comments
 (0)