Commit 2a956ae

fix(tableau/ingestion): allow for changes in Glides schema (#677)
* Write failing test
* Isolate problem test case
* Pass test
* Tweak datetime regex
* Allow glides schema to be out of order
1 parent 9209711 commit 2a956ae
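
The "Tweak datetime regex" bullet refers to the RFC3339_DATETIME_REGEX change in src/lamp_py/ingestion/glides.py below: the separator class [T ] now accepts either a literal "T" or a space between the date and the time. A minimal, standalone sketch of how the widened pattern behaves (the sample timestamps are made up for illustration):

    import re

    # Same pattern as the updated RFC3339_DATETIME_REGEX in the diff below.
    DATE = r"^20(?:([1-3][0-9]-[0-1][0-9]-[0-3][0-9]))"
    DATETIME = DATE + r"[T ]([0-2][0-9]:[0-5][0-9]:[0-5][0-9](?:\.\d+)?)(Z|[\+-]\d{2}:\d{2})?$"

    for ts in ("2025-06-01T14:30:00Z", "2025-06-01 14:30:00.123-04:00", "2025-06-01_14:30:00"):
        print(ts, bool(re.match(DATETIME, ts)))  # True, True, False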

File tree: 3 files changed (+72 -58 lines)

  src/lamp_py/ingestion/glides.py
  src/lamp_py/tableau/jobs/glides.py
  tests/ingestion/test_glides.py
src/lamp_py/ingestion/glides.py

Lines changed: 28 additions & 46 deletions
@@ -8,10 +8,7 @@
 import dataframely as dy
 import polars as pl
 import pyarrow
-import pyarrow.dataset as pd
 import pyarrow.parquet as pq
-import pyarrow.compute as pc
-from dateutil.relativedelta import relativedelta

 from lamp_py.aws.s3 import download_file, upload_file
 from lamp_py.aws.kinesis import KinesisReader
@@ -24,7 +21,7 @@
 )

 RFC3339_DATE_REGEX = r"^20(?:([1-3][0-9]-[0-1][0-9]-[0-3][0-9]))" # up to 2039-19-39
-RFC3339_DATETIME_REGEX = RFC3339_DATE_REGEX + r"T([0-2][0-9]:[0-5][0-9]:[0-5][0-9](?:\.\d+)?)(Z|[\+-]\d{2}:\d{2})?$"
+RFC3339_DATETIME_REGEX = RFC3339_DATE_REGEX + r"[T ]([0-2][0-9]:[0-5][0-9]:[0-5][0-9](?:\.\d+)?)(Z|[\+-]\d{2}:\d{2})?$"
 GTFS_TIME_REGEX = r"^([0-9]{2}):([0-5][0-9]):([0-5][0-9])$" # clock can be greater than 24 hours

 user = dy.Struct(
@@ -67,7 +64,7 @@ class GlidesRecord(dy.Schema):
     id = dy.String()
     type = dy.String()
     time = dy.Datetime( # in %Y-%m-%dT%H:%M:%S%:z format before serialization
-        min=datetime(2024, 1, 1), max=datetime(2039, 12, 31) # within Python's serializable range
+        min=datetime(2024, 1, 1), max=datetime(2039, 12, 31), time_unit="ms" # within Python's serializable range
     )
     source = dy.String()
     specversion = dy.String()
@@ -213,57 +210,36 @@ def download_remote(self) -> None:
         download_file(object_path=self.remote_path, file_name=self.local_path)

     @abstractmethod
-    def convert_records(self) -> pd.Dataset:
+    def convert_records(self) -> dy.DataFrame[GlidesRecord]:
         """Convert incoming records into a flattened table of records"""

     def append_records(self) -> None:
         """Add incoming records to a local parquet file"""
         process_logger = ProcessLogger(process_name="append_glides_records", type=self.type)
         process_logger.log_start()

-        new_dataset = self.convert_records()
+        new_dataset = self.convert_records().lazy()

         if os.path.exists(self.local_path):
-            remote_records = pd.dataset(self.local_path, schema=self.get_table_schema)
-            joined_ds = pd.dataset([new_dataset, remote_records])
+            remote_records = self.table_schema.scan_parquet(self.local_path, validation="allow")
+            joined_ds = pl.union([new_dataset, remote_records])
         else:
             joined_ds = new_dataset

         process_logger.add_metadata(
-            new_records=new_dataset.count_rows(),
-            total_records=joined_ds.count_rows(),
+            new_records=new_dataset.select("time").count().collect().item(),
+            total_records=joined_ds.select("time").count().collect().item(),
         )

-        now = datetime.now()
-        start = datetime(2024, 1, 1)
-
         with tempfile.TemporaryDirectory() as tmp_dir:

             new_path = os.path.join(tmp_dir, self.base_filename)
-            row_group_count = 0
-            with pq.ParquetWriter(new_path, schema=self.get_table_schema) as writer:
-                while start < now:
-                    end = start + relativedelta(months=1)
-                    if end < now:
-                        row_group = pl.DataFrame(
-                            joined_ds.filter((pc.field("time") >= start) & (pc.field("time") < end)).to_table()
-                        )
-
-                    else:
-                        row_group = pl.DataFrame(joined_ds.filter((pc.field("time") >= start)).to_table())
-
-                    if not row_group.is_empty():
-                        unique_table = (
-                            row_group.unique(keep="first").sort(by=["time"]).to_arrow().cast(self.get_table_schema)
-                        )
-
-                        row_group_count += 1
-                        writer.write_table(unique_table)
-
-                    start = end
-
-            os.replace(new_path, self.local_path)
-            process_logger.add_metadata(row_group_count=row_group_count)
+            sorted_ds = joined_ds.unique().sort("time")
+            valid = process_logger.log_dataframely_filter_results(*self.table_schema.filter(sorted_ds))
+            if not valid.is_empty():
+                pq.write_table(valid.to_arrow().cast(self.get_table_schema), new_path)
+                os.replace(new_path, self.local_path)
+            process_logger.add_metadata(row_count=pq.read_metadata(self.local_path).num_rows)

         process_logger.log_complete()

@@ -290,15 +266,17 @@ def __init__(self) -> None:
     def unique_key(self) -> str:
         return "changes"

-    def convert_records(self) -> pd.Dataset:
+    def convert_records(self) -> dy.DataFrame[GlidesRecord]:
         process_logger = ProcessLogger(process_name="convert_records", type=self.type)
         process_logger.log_start()

         editors_table = pyarrow.Table.from_pylist(self.records, schema=self.get_event_schema)
         editors_table = flatten_table_schema(editors_table)
         editors_table = explode_table_column(editors_table, "data.changes")
         editors_table = flatten_table_schema(editors_table)
-        editors_dataset = pd.dataset(editors_table)
+        editors_dataset = process_logger.log_dataframely_filter_results(
+            *EditorChangesTable.filter(pl.DataFrame(editors_table))
+        )

         process_logger.log_complete()
         return editors_dataset
@@ -322,12 +300,14 @@ def __init__(self) -> None:
     def unique_key(self) -> str:
         return "operator"

-    def convert_records(self) -> pd.Dataset:
+    def convert_records(self) -> dy.DataFrame[GlidesRecord]:
         process_logger = ProcessLogger(process_name="convert_records", type=self.type)
         process_logger.log_start()
         osi_table = pyarrow.Table.from_pylist(self.records, schema=self.get_event_schema)
         osi_table = flatten_table_schema(osi_table)
-        osi_dataset = pd.dataset(osi_table)
+        osi_dataset = process_logger.log_dataframely_filter_results(
+            *OperatorSignInsTable.filter(pl.DataFrame(osi_table))
+        )

         process_logger.log_complete()
         return osi_dataset
@@ -348,7 +328,7 @@ def __init__(self) -> None:
     def unique_key(self) -> str:
         return "tripUpdates"

-    def convert_records(self) -> pd.Dataset:
+    def convert_records(self) -> dy.DataFrame[GlidesRecord]:
        def flatten_multitypes(record: Dict) -> Dict:
            """
            For each update in a record, flatten out the objects in "cars",
@@ -374,7 +354,7 @@ def flatten_multitypes(record: Dict) -> Dict:
         tu_table = flatten_table_schema(tu_table)
         tu_table = explode_table_column(tu_table, "data.tripUpdates")
         tu_table = flatten_table_schema(tu_table)
-        tu_dataset = pd.dataset(tu_table)
+        tu_dataset = process_logger.log_dataframely_filter_results(*TripUpdatesTable.filter(pl.DataFrame(tu_table)))

         process_logger.log_complete()
         return tu_dataset
@@ -398,13 +378,15 @@ def __init__(self) -> None:
     def unique_key(self) -> str:
         return "tripKey"

-    def convert_records(self) -> pd.Dataset:
+    def convert_records(self) -> dy.DataFrame[GlidesRecord]:
         process_logger = ProcessLogger(process_name="convert_records", type=self.type)
         process_logger.log_start()

         tu_table = pyarrow.Table.from_pylist(self.records, schema=self.get_event_schema)
         tu_table = flatten_table_schema(tu_table)
-        tu_dataset = pd.dataset(tu_table)
+        tu_dataset = process_logger.log_dataframely_filter_results(
+            *VehicleTripAssignmentTable.filter(pl.DataFrame(tu_table))
+        )

         process_logger.log_complete()
         return tu_dataset
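
For orientation, the rewritten append_records above boils down to: read any existing local parquet file, combine it with the freshly converted records, de-duplicate, sort by time, and atomically replace the local file (the dataframely validation step is omitted here). A minimal, standalone sketch of that flow in plain polars, where pl.concat stands in for the diff's pl.union and the project's table_schema / log_dataframely_filter_results helpers, and file and column names are illustrative:

    import os
    import tempfile

    import polars as pl

    def append_records_sketch(new_frame: pl.DataFrame, local_path: str) -> None:
        """Merge new rows into a local parquet file, de-duplicate, and sort by time."""
        frames = [new_frame.lazy()]
        if os.path.exists(local_path):
            frames.append(pl.scan_parquet(local_path))

        # Diagonal concatenation aligns columns by name, so previously written files
        # whose columns are in a different order (or missing a column) still combine.
        joined = pl.concat(frames, how="diagonal")

        with tempfile.TemporaryDirectory() as tmp_dir:
            new_path = os.path.join(tmp_dir, "records.parquet")
            joined.unique().sort("time").collect().write_parquet(new_path)
            os.replace(new_path, local_path)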

src/lamp_py/tableau/jobs/glides.py

Lines changed: 2 additions & 2 deletions
@@ -128,7 +128,7 @@ def create_trips_updated_glides_parquet(job: HyperJob, num_files: Optional[int])
            # pl.col("data.tripUpdates.endTime").str.to_time("%H:%M:%S", strict=False),
        )

-        writer.write_table(polars_df.to_arrow())
+        writer.write_table(polars_df.select(job.output_processed_schema.names).to_arrow())


 def create_operator_signed_in_glides_parquet(job: HyperJob, num_files: Optional[int]) -> None:
@@ -166,7 +166,7 @@ def create_operator_signed_in_glides_parquet(job: HyperJob, num_files: Optional[
            pl.col("time").dt.convert_time_zone(time_zone="US/Eastern").dt.replace_time_zone(None),
        )

-        writer.write_table(polars_df.to_arrow())
+        writer.write_table(polars_df.select(job.output_processed_schema.names).to_arrow())


 class HyperGlidesTripUpdates(HyperJob):
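
Both write_table changes above select the job's processed-schema columns before converting to Arrow, so a frame that carries extra columns, or columns in a different order than the ParquetWriter's schema, still writes cleanly. A small standalone sketch of the same pattern (the schema, file name, and column names are invented for illustration):

    import polars as pl
    import pyarrow as pa
    import pyarrow.parquet as pq

    # A writer opened with a fixed schema expects tables whose columns match it.
    schema = pa.schema([("time", pa.timestamp("ms")), ("id", pa.string())])

    # Incoming frame with an extra column and a different column order.
    df = pl.DataFrame({"id": ["a"], "extra": [1], "time": [0]}).with_columns(
        pl.col("time").cast(pl.Datetime(time_unit="ms"))
    )

    with pq.ParquetWriter("glides_sketch.parquet", schema) as writer:
        # Selecting the schema's column names drops extras and fixes the order
        # before handing the table to the writer.
        writer.write_table(df.select(schema.names).to_arrow().cast(schema))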

tests/ingestion/test_glides.py

Lines changed: 42 additions & 10 deletions
@@ -1,3 +1,4 @@
+from datetime import datetime
 from os import remove
 from pathlib import Path
 from queue import Queue
@@ -7,7 +8,7 @@
 import dataframely as dy
 import pytest
 import polars as pl
-from polars.testing import assert_frame_equal
+import pyarrow.parquet as pq

 from lamp_py.ingestion.glides import (
     GlidesConverter,
@@ -37,9 +38,14 @@ def test_convert_records(dy_gen: dy.random.Generator, converter: GlidesConverter
     converter.records = converter.record_schema.sample(
         num_rows=num_rows,
         generator=dy_gen,
+        overrides={
+            "time": dy_gen.sample_datetime(
+                num_rows, min=datetime(2024, 1, 1), max=datetime(2039, 12, 31), time_unit="us"
+            ).cast(pl.Datetime(time_unit="ms"))
+        },
     ).to_dicts()

-    table = pl.scan_pyarrow_dataset(converter.convert_records())
+    table = converter.convert_records()

     assert not converter.table_schema.validate(table).is_empty()
     assert converter.table_schema.validate(table).select("id").unique().height == num_rows # all records
@@ -48,8 +54,13 @@

 @pytest.mark.parametrize(
     ["column_transformations"],
-    [({},), ({"id": pl.col("id")},), ({"new_col": pl.lit(1)},)],
-    ids=["no-remote-records", "same-schema", "dropped-column"],
+    [
+        ({},),
+        ({"id": pl.col("id")},),
+        ({"new_col": pl.lit(1)},),
+        ({"time": pl.col("time").cast(pl.Datetime(time_unit="us")).dt.offset_by("1us")},),
+    ],
+    ids=["no-remote-records", "same-schema", "dropped-column", "truncated-timestamp"],
 )
 @pytest.mark.parametrize(
     [
@@ -67,26 +78,42 @@ def test_append_records(
     dy_gen: dy.random.Generator,
     converter: GlidesConverter,
     tmp_path: Path,
-    column_transformations: dict,
+    column_transformations: dict[str, pl.Expr],
     num_rows: int = 5,
 ) -> None:
-    """It writes all records locally."""
+    """It writes all records locally using the table schema."""
     converter.records = converter.record_schema.sample(
         num_rows=num_rows,
         generator=dy_gen,
+        overrides={
+            "time": dy_gen.sample_datetime(
+                num_rows, min=datetime(2024, 1, 1), max=datetime(2039, 12, 31), time_unit="us"
+            ).cast(pl.Datetime(time_unit="ms"))
+        },
     ).to_dicts()

     converter.local_path = tmp_path.joinpath(converter.base_filename).as_posix()

-    expectation = pl.scan_pyarrow_dataset(converter.convert_records()).collect()
+    expectation = converter.convert_records()

+    remote_records_height = 0
     if column_transformations:
-        remote_records = expectation.with_columns(**column_transformations)
+        remote_records = converter.table_schema.sample(
+            num_rows,
+            generator=dy_gen,
+            overrides={
+                "time": dy_gen.sample_datetime(
+                    num_rows, min=datetime(2024, 1, 1), max=datetime(2039, 12, 31), time_unit="us"
+                ).cast(pl.Datetime(time_unit="ms"))
+            },
+        ).with_columns(**column_transformations)
         remote_records.write_parquet(converter.local_path)
+        remote_records_height = remote_records.height

     converter.append_records()

-    assert_frame_equal(expectation, pl.read_parquet(converter.local_path), check_row_order=False)
+    assert pq.read_schema(converter.local_path) == converter.get_table_schema
+    assert pq.read_metadata(converter.local_path).num_rows == expectation.height + remote_records_height


 @pytest.mark.parametrize(
@@ -102,13 +129,18 @@ def test_append_records(
     ids=["editor-changes", "operator-sign-ins", "trip-updates", "vehicle-trip-assignments"],
 )
 def test_ingest_glides_events(
-    converter: GlidesConverter, dy_gen: dy.random.Generator, mocker: MockerFixture, events_per_converter: int = 500
+    converter: GlidesConverter, dy_gen: dy.random.Generator, mocker: MockerFixture, events_per_converter: int = 50
 ) -> None:
     """It routes events to correct converter and writes them to specified storage."""
     test_records = (
         converter.record_schema.sample( # generate test records
             num_rows=events_per_converter,
             generator=dy_gen,
+            overrides={
+                "time": dy_gen.sample_datetime(
+                    events_per_converter, min=datetime(2024, 1, 1), max=datetime(2039, 12, 31), time_unit="us"
+                ).cast(pl.Datetime(time_unit="ms"))
+            },
         )
         .with_columns(
             time=pl.col("time").dt.strftime("%Y-%m-%dT%H:%M:%SZ")
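
The time overrides added throughout these tests, and the new truncated-timestamp case, revolve around millisecond precision: the generator samples microsecond datetimes which are cast to the schema's millisecond unit, and the truncated-timestamp parametrization then nudges stored records by one microsecond to simulate precision drift. The core behavior is that casting to pl.Datetime(time_unit="ms") drops sub-millisecond detail; a standalone polars sketch (the sample value is made up):

    from datetime import datetime

    import polars as pl

    # A microsecond-precision timestamp, as a generator might sample it.
    us = pl.Series("time", [datetime(2025, 6, 1, 12, 0, 0, 123456)], dtype=pl.Datetime(time_unit="us"))

    # Casting to the schema's millisecond unit truncates the trailing microseconds.
    ms = us.cast(pl.Datetime(time_unit="ms"))

    print(us[0])  # 2025-06-01 12:00:00.123456
    print(ms[0])  # 2025-06-01 12:00:00.123000
    assert ms[0] != us[0]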
