Skip to content

Commit 0c1ffef

Browse files
Merge pull request #67 from NREL/rjf/omf-revisions
Rjf/omf revisions
2 parents 3959dd6 + 52fab62 commit 0c1ffef

File tree

5 files changed

+131
-82
lines changed

5 files changed

+131
-82
lines changed

rust/bambam-omf/src/collection/collector.rs

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use arrow::array::RecordBatch;
22
use chrono::NaiveDate;
33
use futures::stream::{self, StreamExt};
4+
use itertools::Itertools;
45
use object_store::{path::Path, ListResult, ObjectMeta, ObjectStore};
56
use parquet::arrow::arrow_reader::ArrowPredicate;
67
use parquet::arrow::arrow_reader::ArrowReaderOptions;
78
use parquet::arrow::async_reader::ParquetObjectReader;
89
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
910
use rayon::prelude::*;
10-
use serde::de::DeserializeOwned;
1111
use std::sync::Arc;
1212
use std::time::Instant;
1313

@@ -184,47 +184,29 @@ impl OvertureMapsCollector {
184184
.collect::<Result<Vec<RecordBatch>, _>>()
185185
.map_err(|e| OvertureMapsCollectionError::RecordBatchRetrievalError { source: e })?;
186186

187-
// Deserialize batches into record types
188-
let records: Vec<Vec<OvertureRecord>> = match record_type {
189-
OvertureRecordType::Places => record_batches
190-
.par_iter()
191-
.map(deserialize_batch::<PlacesRecord>)
192-
.map(|records_result| {
193-
records_result
194-
.map(|records| records.into_iter().map(OvertureRecord::Places).collect())
195-
})
196-
.collect::<Result<Vec<_>, OvertureMapsCollectionError>>()?,
197-
OvertureRecordType::Buildings => record_batches
198-
.par_iter()
199-
.map(deserialize_batch::<BuildingsRecord>)
200-
.map(|records_result| {
201-
records_result
202-
.map(|records| records.into_iter().map(OvertureRecord::Buildings).collect())
203-
})
204-
.collect::<Result<Vec<_>, OvertureMapsCollectionError>>()?,
205-
OvertureRecordType::Segment => record_batches
206-
.par_iter()
207-
.map(deserialize_batch::<TransportationSegmentRecord>)
208-
.map(|records_result| {
209-
records_result
210-
.map(|records| records.into_iter().map(OvertureRecord::Segment).collect())
211-
})
212-
.collect::<Result<Vec<_>, OvertureMapsCollectionError>>()?,
213-
OvertureRecordType::Connector => record_batches
214-
.par_iter()
215-
.map(deserialize_batch::<TransportationConnectorRecord>)
216-
.map(|records_result| {
217-
records_result
218-
.map(|records| records.into_iter().map(OvertureRecord::Connector).collect())
219-
})
220-
.collect::<Result<Vec<_>, OvertureMapsCollectionError>>()?,
221-
};
222-
log::info!("Deserialization time {:?}", start_collection.elapsed());
187+
let start_deserialization = Instant::now();
188+
let records: Vec<OvertureRecord> = record_batches
189+
.par_iter()
190+
.map(|batch| match record_type {
191+
OvertureRecordType::Places => record_type.process_batch::<PlacesRecord>(batch),
192+
OvertureRecordType::Buildings => {
193+
record_type.process_batch::<BuildingsRecord>(batch)
194+
}
195+
OvertureRecordType::Segment => {
196+
record_type.process_batch::<TransportationSegmentRecord>(batch)
197+
}
198+
OvertureRecordType::Connector => {
199+
record_type.process_batch::<TransportationConnectorRecord>(batch)
200+
}
201+
})
202+
.collect::<Result<Vec<_>, _>>()?
203+
.into_iter()
204+
.flatten()
205+
.collect_vec();
223206

224-
// Flatten the collection
225-
let flatten_records = records.into_iter().flatten().collect();
207+
log::info!("Deserialization time {:?}", start_deserialization.elapsed());
226208
log::info!("Total time {:?}", start_collection.elapsed());
227-
Ok(flatten_records)
209+
Ok(records)
228210
}
229211

230212
pub fn collect_from_release(
@@ -237,21 +219,12 @@ impl OvertureMapsCollector {
237219
ReleaseVersion::Latest => self.get_latest_release()?,
238220
other => String::from(other),
239221
};
240-
log::info!("Collecting OvertureMaps records from release {release_str}");
222+
log::info!("Collecting OvertureMaps {record_type} records from release {release_str}");
241223
let path = Path::from(record_type.format_url(release_str));
242224
self.collect_from_path(path, record_type, row_filter_config)
243225
}
244226
}
245227

246-
/// Deserialize recordBatch into type T
247-
fn deserialize_batch<T>(record_batch: &RecordBatch) -> Result<Vec<T>, OvertureMapsCollectionError>
248-
where
249-
T: DeserializeOwned,
250-
{
251-
serde_arrow::from_record_batch(record_batch)
252-
.map_err(|e| OvertureMapsCollectionError::DeserializeError(format!("Serde error: {e}")))
253-
}
254-
255228
#[cfg(test)]
256229
mod test {
257230
use crate::collection::{

rust/bambam-omf/src/collection/filter/bbox_row_predicate.rs

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use arrow::{
55
};
66
use parquet::arrow::arrow_reader::ArrowPredicate;
77

8+
/// tests if a row is contained within a bounding box.
89
pub struct BboxRowPredicate {
910
bbox: Bbox,
1011
projection_mask: parquet::arrow::ProjectionMask,
@@ -24,11 +25,13 @@ impl ArrowPredicate for BboxRowPredicate {
2425
&self.projection_mask
2526
}
2627

28+
/// tests the bounding box of each row in the record batch, filtering entries
29+
/// that are not fully-contained.
2730
fn evaluate(
2831
&mut self,
2932
batch: arrow::array::RecordBatch,
3033
) -> Result<arrow::array::BooleanArray, arrow::error::ArrowError> {
31-
let struct_array = batch
34+
let bbox_struct = batch
3235
.column_by_name("bbox")
3336
.ok_or(ArrowError::ParquetError(String::from(
3437
"`bbox` column not found",
@@ -39,36 +42,47 @@ impl ArrowPredicate for BboxRowPredicate {
3942
"Cannot cast column `bbox` to StructArray type",
4043
)))?;
4144

42-
let x_min_col = struct_array
43-
.column_by_name("xmin")
44-
.ok_or(ArrowError::ParquetError(String::from(
45-
"`bbox.xmin` column not found",
46-
)))?
47-
.as_any()
48-
.downcast_ref::<Float32Array>()
49-
.ok_or(ArrowError::ParquetError(String::from(
50-
"Cannot cast column `bbox.xmin` to Float32Array type",
51-
)))?;
45+
let xmins = get_column::<Float32Array>("xmin", bbox_struct)?;
46+
let ymins = get_column::<Float32Array>("ymin", bbox_struct)?;
47+
let xmaxs = get_column::<Float32Array>("xmax", bbox_struct)?;
48+
let ymaxs = get_column::<Float32Array>("ymax", bbox_struct)?;
5249

53-
let y_min_col = struct_array
54-
.column_by_name("ymin")
55-
.ok_or(ArrowError::ParquetError(String::from(
56-
"`bbox.ymin` column not found",
57-
)))?
58-
.as_any()
59-
.downcast_ref::<Float32Array>()
60-
.ok_or(ArrowError::ParquetError(String::from(
61-
"Cannot cast column `bbox.ymin` to Float32Array type",
62-
)))?;
63-
64-
let boolean_values: Vec<bool> = (0..struct_array.len())
65-
.map(|i| {
66-
self.bbox.xmin < x_min_col.value(i)
67-
&& x_min_col.value(i) < self.bbox.xmax
68-
&& self.bbox.ymin < y_min_col.value(i)
69-
&& y_min_col.value(i) < self.bbox.ymax
70-
})
50+
let boolean_values: Vec<bool> = (0..bbox_struct.len())
51+
.map(|i| within_box(i, xmins, ymins, xmaxs, ymaxs, &self.bbox))
7152
.collect();
7253
Ok(BooleanArray::from(boolean_values))
7354
}
7455
}
56+
57+
/// helper function to get a column by name from a struct array and return it as
58+
/// the expected type.
59+
fn get_column<'b, T>(col: &str, struct_array: &'b StructArray) -> Result<&'b T, ArrowError>
60+
where
61+
T: 'static,
62+
{
63+
struct_array
64+
.column_by_name(col)
65+
.ok_or(ArrowError::ParquetError(format!(
66+
"'bbox.{col}' column not found"
67+
)))?
68+
.as_any()
69+
.downcast_ref::<T>()
70+
.ok_or(ArrowError::ParquetError(format!(
71+
"Cannot cast column 'bbox.{col}' to expected type"
72+
)))
73+
}
74+
75+
/// helper function to test whether a given row's values are contained within the bounding box.
76+
fn within_box(
77+
index: usize,
78+
xmins: &Float32Array,
79+
ymins: &Float32Array,
80+
xmaxs: &Float32Array,
81+
ymaxs: &Float32Array,
82+
bbox: &Bbox,
83+
) -> bool {
84+
bbox.xmin <= xmins.value(index)
85+
&& xmaxs.value(index) <= bbox.xmax
86+
&& bbox.ymin <= ymins.value(index)
87+
&& ymaxs.value(index) <= bbox.ymax
88+
}

rust/bambam-omf/src/collection/record/overture_record.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,27 @@ pub enum OvertureRecord {
99
Segment(TransportationSegmentRecord),
1010
Connector(TransportationConnectorRecord),
1111
}
12+
13+
impl From<PlacesRecord> for OvertureRecord {
14+
fn from(value: PlacesRecord) -> Self {
15+
Self::Places(value)
16+
}
17+
}
18+
19+
impl From<BuildingsRecord> for OvertureRecord {
20+
fn from(value: BuildingsRecord) -> Self {
21+
Self::Buildings(value)
22+
}
23+
}
24+
25+
impl From<TransportationSegmentRecord> for OvertureRecord {
26+
fn from(value: TransportationSegmentRecord) -> Self {
27+
Self::Segment(value)
28+
}
29+
}
30+
31+
impl From<TransportationConnectorRecord> for OvertureRecord {
32+
fn from(value: TransportationConnectorRecord) -> Self {
33+
Self::Connector(value)
34+
}
35+
}

rust/bambam-omf/src/collection/record/record_type.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
use arrow::array::RecordBatch;
2+
use serde::de::DeserializeOwned;
3+
4+
use crate::collection::{OvertureMapsCollectionError, OvertureRecord};
5+
16
pub enum OvertureRecordType {
27
Places,
38
Buildings,
@@ -22,4 +27,31 @@ impl OvertureRecordType {
2227
}
2328
}
2429
}
30+
31+
/// processes an arrow [RecordBatch] into an [OvertureRecord] collection,
32+
/// deserializing into the underlying row type struct along the way.
33+
pub fn process_batch<R>(
34+
&self,
35+
record_batch: &RecordBatch,
36+
) -> Result<Vec<OvertureRecord>, OvertureMapsCollectionError>
37+
where
38+
R: DeserializeOwned + Into<OvertureRecord>,
39+
{
40+
let as_rows: Vec<R> = serde_arrow::from_record_batch(record_batch).map_err(|e| {
41+
OvertureMapsCollectionError::DeserializeError(format!("Serde error: {e}"))
42+
})?;
43+
let as_result: Vec<OvertureRecord> = as_rows.into_iter().map(Into::into).collect();
44+
Ok(as_result)
45+
}
46+
}
47+
48+
impl std::fmt::Display for OvertureRecordType {
49+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50+
match self {
51+
Self::Places => write!(f, "Places"),
52+
Self::Buildings => write!(f, "Buildings"),
53+
Self::Segment => write!(f, "Segment"),
54+
Self::Connector => write!(f, "Connector"),
55+
}
56+
}
2557
}

script/setup_test_bambam.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
print("importing osmnx, compass")
22
import osmnx as ox
33
from nrel.routee.compass.io import generate_compass_dataset
4-
4+
from nrel.routee.compass.io.generate_dataset import GeneratePipelinePhase
55

66
if __name__ == "__main__":
7+
8+
phases = [
9+
GeneratePipelinePhase.CONFIG,
10+
GeneratePipelinePhase.GRAPH
11+
]
12+
713
print("downloading graph")
814
g = ox.graph_from_place("Denver, Colorado, USA", network_type="drive")
915
print("processing graph into compass dataset")
10-
generate_compass_dataset(g, output_directory="denver_co")
16+
generate_compass_dataset(g, output_directory="denver_co", phases=phases)
1117

1218
# Boulder graph for GTFS
1319
g = ox.graph_from_place("Boulder, Colorado, USA", network_type="drive")
1420
print("processing graph into compass dataset")
15-
generate_compass_dataset(g, output_directory="boulder_co")
21+
generate_compass_dataset(g, output_directory="boulder_co", phases=phases)

0 commit comments

Comments
 (0)