diff --git a/.github/workflows/rust-release.yaml b/.github/workflows/rust-release.yaml index 28517eb2..e990d5df 100644 --- a/.github/workflows/rust-release.yaml +++ b/.github/workflows/rust-release.yaml @@ -26,6 +26,9 @@ jobs: - name: Run unit tests run: cargo test --workspace --manifest-path=rust/Cargo.toml -- --skip test_e2e + - name: Publish bambam-core + run: cargo publish -p bambam-core --manifest-path=rust/Cargo.toml --token "${{ secrets.CRATES_IO_TOKEN }}" + - name: Publish bambam-osm run: cargo publish -p bambam-osm --manifest-path=rust/Cargo.toml --token "${{ secrets.CRATES_IO_TOKEN }}" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 01dca801..2ae48472 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -7,6 +7,7 @@ members = [ "bambam-omf", "bambam-osm", "bambam-py", + "bambam-core" ] [workspace.dependencies] diff --git a/rust/bambam-core/Cargo.toml b/rust/bambam-core/Cargo.toml new file mode 100644 index 00000000..6cdff2d5 --- /dev/null +++ b/rust/bambam-core/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "bambam-core" +version = "0.2.3" +edition = "2021" +license = "BSD-3-Clause" +exclude = ["test", "**/.DS_Store", "target/"] +readme = "README.md" +repository = "https://github.com/NREL/bambam" +documentation = "https://docs.rs/bambam" +description = "The Behavior and Advanced Mobility Big Access Model" +keywords = [ + "nrel", + "access-model", + "accessibility", + "multimodal", + "transit", +] +categories = ["command-line-utilities", "science", "science::geo"] + +[dependencies] +itertools = { workspace = true } +routee-compass = { workspace = true } +routee-compass-core = { workspace = true } +routee-compass-powertrain = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +uom = { workspace = true } +wkt = { workspace = true } +geo = { workspace = true } +geo-types = { workspace = true } +hex = { workspace = true } +geo-traits = { workspace = true } +geojson = { workspace = true } +wkb = { workspace = true } +rstar = { workspace = true } +chrono = { workspace = true } \ No newline at end of file diff --git a/rust/bambam-core/README.md b/rust/bambam-core/README.md new file mode 100644 index 00000000..e69de29b diff --git a/rust/bambam-core/src/lib.rs b/rust/bambam-core/src/lib.rs new file mode 100644 index 00000000..70560358 --- /dev/null +++ b/rust/bambam-core/src/lib.rs @@ -0,0 +1,2 @@ +pub mod model; +pub mod util; diff --git a/rust/bambam-core/src/model/bambam_field.rs b/rust/bambam-core/src/model/bambam_field.rs new file mode 100644 index 00000000..8a236689 --- /dev/null +++ b/rust/bambam-core/src/model/bambam_field.rs @@ -0,0 +1,326 @@ +//! Fields and types assigned to the JSON output during bambam runs. +//! +//! # Examples +//! +//! ### Aggregate Data Rows +//! +//! ```json +//! { +//! "opportunity_format": "aggregate", +//! "opportunity_totals": {}, +//! "activity_types": [], +//! "info": { +//! "opportunity_runtime": "hh:mm:ss", +//! "mep_runtime": "hh:mm:ss", +//! "tree_size": 0, +//! } +//! "bin": { +//! 10: { +//! "isochrone": {}, +//! "opportunities" {}, +//! "mep": {}, +//! "info": { +//! "time_bin": { .. }, +//! "bin_runtime": +//! }, +//! } +//! } +//! } +//! ``` +//! +//! ### Disaggregate Data Rows +//! ```json +//! { +//! "opportunity_format": "disaggregate", +//! "opportunity_totals": {}, +//! "activity_types": [], +//! "opportunities": { +//! "{EdgeListId}-{EdgeId}": { +//! "counts": {}, +//! "state": [] +//! } +//! } +//! } +//! ``` +//! +use crate::model::TimeBin; +use itertools::Itertools; +use routee_compass::plugin::output::OutputPluginError; +use serde::de::DeserializeOwned; +use serde_json::{json, Value}; + +pub const TIME_BINS: &str = "bin"; +pub const TIME_BIN: &str = "time_bin"; +pub const INFO: &str = "info"; +pub const MODE: &str = "mode"; +pub const ISOCHRONE: &str = "isochrone"; +pub const ISOCHRONE_FORMAT: &str = "isochrone_format"; +pub const TREE_SIZE: &str = "tree_size"; +pub const ACTIVITY_TYPES: &str = "activity_types"; +pub const OPPORTUNITIES: &str = "opportunities"; +pub const OPPORTUNITY_COUNTS: &str = "opportunity_counts"; +pub const OPPORTUNITY_ORIENTATION: &str = "opportunity_orientation"; +pub const OPPORTUNITY_FORMAT: &str = "opportunity_format"; +pub const OPPORTUNITY_TOTALS: &str = "opportunity_totals"; +pub const VEHICLE_STATE: &str = "vehicle_state"; +pub const OPP_FMT_AGGREGATE: &str = "aggregate"; +pub const OPP_FMT_DISAGGREGATE: &str = "disaggregate"; +pub const OPPORTUNITY_PLUGIN_RUNTIME: &str = "opportunity_runtime"; +pub const OPPORTUNITY_BIN_RUNTIME: &str = "bin_runtime"; + +pub mod get { + use itertools::Itertools; + use routee_compass::plugin::output::OutputPluginError; + use routee_compass_core::model::{ + network::{EdgeId, EdgeListId, VertexId}, + state::StateVariable, + }; + use serde::de::DeserializeOwned; + use serde_json::Value; + use std::collections::HashMap; + + use crate::model::{ + bambam_field::as_usize, + output_plugin::{ + isochrone::IsochroneOutputFormat, + opportunity::{OpportunityFormat, OpportunityOrientation}, + }, + }; + + pub fn mode(value: &Value) -> Result { + let path = ["request", super::MODE]; + super::get_nested(value, &path).map_err(|e| { + let dot_path = path.join("."); + OutputPluginError::OutputPluginFailed(format!( + "failure retrieving 'mode' value from '{dot_path}': {e}" + )) + }) + } + pub fn activity_types(value: &Value) -> Result, OutputPluginError> { + get_from_value(super::ACTIVITY_TYPES, value) + } + pub fn isochrone_format(value: &Value) -> Result { + get_from_value(super::ISOCHRONE_FORMAT, value) + } + pub fn opportunity_format(value: &Value) -> Result { + get_from_value(super::OPPORTUNITY_FORMAT, value) + } + pub fn opportunity_orientation( + value: &Value, + ) -> Result { + get_from_value(super::OPPORTUNITY_ORIENTATION, value) + } + pub fn disaggregate_vertex_id(value: &str) -> Result { + let id: usize = super::as_usize(value)?; + Ok(VertexId(id)) + } + pub fn disaggregate_edge_id(value: &str) -> Result<(EdgeListId, EdgeId), OutputPluginError> { + match value.split("-").collect_vec()[..] { + [] => Err(OutputPluginError::OutputPluginFailed("disaggregate edge identifier is empty".to_string())), + [edge_list_str, edge_str] => { + let edge_list_id = EdgeListId(as_usize(edge_list_str)?); + let edge_id = EdgeId(as_usize(edge_str)?); + Ok((edge_list_id, edge_id)) + }, + _ => Err(OutputPluginError::OutputPluginFailed(format!("disaggregate edge identifier is malformed, expected '-', found '{value}'"))) + } + } + pub fn totals(value: &Value) -> Result, OutputPluginError> { + get_from_value(super::OPPORTUNITY_TOTALS, value) + } + pub fn counts(value: &Value) -> Result, OutputPluginError> { + get_from_value(super::OPPORTUNITY_COUNTS, value) + } + pub fn state(value: &Value) -> Result, OutputPluginError> { + get_from_value(super::VEHICLE_STATE, value) + } + + /// helper for deserializing fields from a JSON value in a deserializable type + fn get_from_value(field: &str, value: &Value) -> Result + where + T: DeserializeOwned, + { + let value = value.get(field).ok_or_else(|| { + OutputPluginError::InternalError(format!("cannot find '{field}' in output row")) + })?; + serde_json::from_value(value.clone()).map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "found '{field}' in output row but cannot deserialize due to: {e}" + )) + }) + } +} + +mod set {} + +/// gets a deserialized value from a json object at some path. not compatible with json arrays. +pub fn get_nested(json: &Value, path: &[&str]) -> Result { + let mut cursor = json; + for k in path { + match cursor.get(k) { + Some(child) => { + cursor = child; + } + None => return Err(nested_error("get", path.to_vec(), k, cursor)), + } + } + let result = serde_json::from_value(cursor.clone()) + .map_err(|e| format!("unable to deserialize value '{cursor}': {e}"))?; + Ok(result) +} + +/// inserts a json value into a json object at some path, adding any missing parent objects +/// along the way. not compatible with json arrays. +pub fn insert_nested_with_parents( + json: &mut Value, + path: &[&str], + key: &str, + value: Value, + overwrite: bool, +) -> Result<(), String> { + let parents = path.to_vec(); + for i in 0..parents.len() { + let key = parents[i]; + insert_nested(json, &parents[0..i], key, json![{}], false)?; + } + insert_nested(json, path, key, value, overwrite) +} + +/// inserts a json value into a json object at some path. not compatible with json arrays. +pub fn insert_nested( + json: &mut Value, + path: &[&str], + key: &str, + value: Value, + overwrite: bool, +) -> Result<(), String> { + let mut cursor = json; + for k in path { + if cursor.get(k).is_none() { + return Err(nested_error("insert", path.to_vec(), k, cursor)); + }; + match cursor.get_mut(k) { + Some(child) => { + cursor = child; + } + None => unreachable!("invariant: already None-checked above"), + } + } + let exists = cursor.get(key).is_some(); + if exists && !overwrite { + Ok(()) + } else { + cursor[key] = value; + Ok(()) + } +} + +/// assures that the structure exists for a time bin. +/// +/// +/// with time bin [0, 10]: +/// +/// { +/// "bin": { +/// "10": { +/// "info": { "time_bin": { .. } }, +/// } +/// } +/// } +pub fn scaffold_time_bin(json: &mut Value, time_bin: &TimeBin) -> Result<(), String> { + if json.get(TIME_BINS).is_none() { + json[TIME_BINS] = json![{}]; + } + let time_bin_key = time_bin.key(); + insert_nested(json, &[TIME_BINS], &time_bin_key, json![{}], false)?; + insert_nested( + json, + &[TIME_BINS, &time_bin_key], + INFO, + json![{ TIME_BIN: json![time_bin] }], + false, + )?; + Ok(()) +} + +type TimeBinsIter<'a> = Box> + 'a>; +type TimeBinsIterMut<'a> = Box> + 'a>; + +pub fn get_time_bins(output: &serde_json::Value) -> Result, String> { + let bins_value = output + .get(TIME_BINS) + .ok_or_else(|| field_error(vec![TIME_BINS]))?; + let bins = bins_value + .as_object() + .ok_or_else(|| type_error(vec![TIME_BINS], String::from("JSON object")))? + .values() + .map(|v| get_nested(v, &[INFO, TIME_BIN])) + .collect::, _>>()?; + Ok(bins) +} + +pub fn time_bins_iter(output: &serde_json::Value) -> Result, String> { + let bins_value = output + .get(TIME_BINS) + .ok_or_else(|| field_error(vec![TIME_BINS]))?; + let bins = bins_value + .as_object() + .ok_or_else(|| type_error(vec![TIME_BINS], String::from("JSON object")))? + .values() + .map(|v| { + let time_bin = get_nested(v, &[INFO, TIME_BIN]); + time_bin.map(|t| (t, v)) + }); + Ok(Box::new(bins)) +} + +pub fn time_bins_iter_mut(output: &mut serde_json::Value) -> Result, String> { + let bins_value = output + .get_mut(TIME_BINS) + .ok_or_else(|| field_error(vec![TIME_BINS]))?; + let bins = bins_value + .as_object_mut() + .ok_or_else(|| type_error(vec![TIME_BINS], String::from("JSON object")))? + .values_mut() + .map(move |v| { + let time_bin = get_nested(v, &[INFO, TIME_BIN]); + time_bin.map(|t| (t, v)) + }); + Ok(Box::new(bins)) +} + +fn field_error(fields: Vec<&str>) -> String { + let path = fields.join("."); + format!("expected path {path} missing from output object") +} + +fn nested_error(action: &str, fields: Vec<&str>, failed_key: &str, object: &Value) -> String { + let path = fields.join("."); + let keylist = object + .as_object() + .map(|o| o.keys().collect_vec()) + .unwrap_or_default(); + let keys = if keylist.len() > 5 { + let inner = keylist.iter().take(5).join(", "); + format!("[{inner}, ...]") + } else { + let inner = keylist.iter().join(", "); + format!("[{inner}]") + }; + format!( + "during {action}, expected path '{path}' missing key '{failed_key}' from JSON object available sibling keys: {keys}" + ) +} + +fn type_error(fields: Vec<&str>, expected_type: String) -> String { + let path = fields.join("."); + format!("expected value at path {path} to be {expected_type}") +} + +fn as_usize(value: &str) -> Result { + value.parse().map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "unable to read oppportunity key '{value}' as a numeric value: {e}" + )) + }) +} diff --git a/rust/bambam-core/src/model/bambam_ops.rs b/rust/bambam-core/src/model/bambam_ops.rs new file mode 100644 index 00000000..f133a6cf --- /dev/null +++ b/rust/bambam-core/src/model/bambam_ops.rs @@ -0,0 +1,155 @@ +use crate::model::bambam_state; + +use super::{bambam_field, TimeBin}; +use geo::{line_measures::LengthMeasurable, Haversine, InterpolatableLine, LineString, Point}; +use routee_compass::{app::search::SearchAppResult, plugin::PluginError}; +use routee_compass_core::{ + algorithm::search::SearchTreeNode, + model::{ + label::Label, + state::{StateModel, StateModelError, StateVariable}, + unit::DistanceUnit, + }, +}; +use std::collections::HashMap; +use uom::{ + si::f64::{Length, Time}, + ConstZero, +}; +use wkt::ToWkt; + +pub type DestinationsIter<'a> = + Box> + 'a>; + +/// collects search tree branches that can be reached _as destinations_ +/// within the given time bin. +pub fn collect_destinations<'a>( + search_result: &'a SearchAppResult, + time_bin: Option<&'a TimeBin>, + state_model: &'a StateModel, +) -> DestinationsIter<'a> { + match search_result.trees.first() { + None => Box::new(std::iter::empty()), + Some(tree) => { + let tree_destinations = + tree.iter() + .filter_map(move |(label, branch)| match branch.incoming_edge() { + None => None, + Some(et) => { + let result_state = &et.result_state; + let within_bin = match &time_bin { + Some(bin) => bin.state_time_within_bin(result_state, state_model), + None => Ok(true), + }; + match within_bin { + Ok(true) => Some(Ok((label.clone(), branch))), + Ok(false) => None, + Err(e) => Some(Err(e)), + } + } + }); + + Box::new(tree_destinations) + } + } +} + +pub fn points_along_linestring( + linestring: &LineString, + stride: &Length, + _distance_unit: &DistanceUnit, +) -> Result>, String> { + let length: Length = + Length::new::(linestring.length(&Haversine) as f64); + + if &length < stride { + match (linestring.points().next(), linestring.points().next_back()) { + (Some(first), Some(last)) => Ok(vec![first, last]), + _ => Err(format!( + "invalid linestring, should have at least two points: {linestring:?}" + )), + } + } else { + // determine number of steps + let n_strides = (length / *stride).value.ceil() as u64; + let n_points = n_strides + 1; + + let length_meters = length.value; + + (0..=n_points) + .map(|point_index| { + let distance_to_point = stride.value * point_index as f64; + let fraction = (distance_to_point / length_meters) as f32; + let point = linestring + .point_at_ratio_from_start(&Haversine, fraction) + .ok_or_else(|| { + format!( + "unable to interpolate {}m/{}% into linestring with distance {}: {}", + distance_to_point, + (fraction * 10000.0).trunc() / 100.0, + length_meters, + linestring.to_wkt() + ) + })?; + Ok(point) + }) + .collect::, String>>() + } +} + +pub fn accumulate_global_opps( + opps: &[(usize, Vec)], + colnames: &[String], +) -> Result, PluginError> { + let mut result: HashMap = HashMap::new(); + for (_, row) in opps.iter() { + for (idx, value) in row.iter().enumerate() { + let colname = colnames.get(idx).ok_or_else(|| { + PluginError::InternalError( + "opportunity count row and activity types list do not match".to_string(), + ) + })?; + if let Some(val) = result.get_mut(colname) { + *val += value; + } else { + result.insert(colname.to_string(), *value); + } + } + } + Ok(result) +} + +/// helper that combines the arrival delay with the traversal time to produce +/// the time to reach this point and call it a destination. +pub fn get_reachability_time( + state: &[StateVariable], + state_model: &StateModel, +) -> Result { + let trip_time = state_model.get_time(state, bambam_state::TRIP_TIME)?; + let has_delay = state_model.contains_key(&bambam_state::TRIP_ARRIVAL_DELAY.to_string()); + let arrival_delay = if has_delay { + state_model.get_time(state, bambam_state::TRIP_ARRIVAL_DELAY)? + } else { + Time::ZERO + }; + Ok(trip_time + arrival_delay) +} + +/// steps through each bin's output section for mutable updates +pub fn iterate_bins<'a>( + output: &'a mut serde_json::Value, +) -> Result + 'a>, PluginError> { + let bins = output.get_mut(bambam_field::TIME_BINS).ok_or_else(|| { + PluginError::UnexpectedQueryStructure(format!( + "after running json structure plugin, cannot find key {}", + bambam_field::TIME_BINS + )) + })?; + let bins_map = bins.as_object_mut().ok_or_else(|| { + PluginError::UnexpectedQueryStructure(format!( + "after running json structure plugin, field {} was not a key/value map", + bambam_field::TIME_BINS + )) + })?; + Ok(Box::new(bins_map.iter_mut())) +} diff --git a/rust/bambam-core/src/model/bambam_state.rs b/rust/bambam-core/src/model/bambam_state.rs new file mode 100644 index 00000000..b2802358 --- /dev/null +++ b/rust/bambam-core/src/model/bambam_state.rs @@ -0,0 +1,28 @@ +//! state feature names assigned to the state model in bambam runs. also exports +//! the upstream-defined features from compass. + +/// time delays accumulated throughout the trip +pub const TRIP_ENROUTE_DELAY: &str = "trip_enroute_delay"; + +/// time delays on arriving at a destination, such as parking, which +/// are not incorporated into the search cost function. +pub const TRIP_ARRIVAL_DELAY: &str = "trip_arrival_delay"; + +/// during scheduled mode traversals, a record of the route used. +pub const ROUTE_ID: &str = "route_id"; + +/// a record of the total "switching mode" time. currently used in transit traversal to model waiting time +pub const TRANSIT_BOARDING_TIME: &str = "transit_boarding_time"; +/// a record of the total time sitting on transit during dwell in between edge traversals. +pub const DWELL_TIME: &str = "dwell_time"; + +/// used to penalize an edge. convention is to design this +/// as one of the vehicle cost rates, via a "raw" interpretation +/// (no cost conversion) and then to use "mul" (multiplicitive) +/// cost aggregation with this value and the total edge time. +/// when this value is 1.0, no penalty is applied. +/// if it is < 1, it reduces cost, and > 1, increases cost. +pub const COST_PENALTY_FACTOR: &str = "penalty_factor"; + +pub use routee_compass_core::model::traversal::default::fieldname::*; +pub use routee_compass_powertrain::model::fieldname::*; diff --git a/rust/bambam-core/src/model/mod.rs b/rust/bambam-core/src/model/mod.rs new file mode 100644 index 00000000..f9364751 --- /dev/null +++ b/rust/bambam-core/src/model/mod.rs @@ -0,0 +1,8 @@ +pub mod bambam_field; +pub mod bambam_state; +pub mod state; +mod time_bin; + +pub use time_bin::TimeBin; +pub mod bambam_ops; +pub mod output_plugin; diff --git a/rust/bambam-core/src/model/output_plugin/isochrone/isochrone_output_format.rs b/rust/bambam-core/src/model/output_plugin/isochrone/isochrone_output_format.rs new file mode 100644 index 00000000..fba8f2ce --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/isochrone/isochrone_output_format.rs @@ -0,0 +1,139 @@ +use geo::{Geometry, MapCoords, TryConvert}; +use geo_traits::to_geo::ToGeoGeometry; +use geojson; +use routee_compass::plugin::output::OutputPluginError; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use wkb; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Deserialize, Serialize, Clone, Debug)] +#[serde(rename_all = "snake_case")] +pub enum IsochroneOutputFormat { + Wkt, + Wkb, + GeoJson, +} + +impl IsochroneOutputFormat { + pub fn empty_geometry(&self) -> Result { + let empty: Geometry = Geometry::Polygon(geo::polygon![]); + self.serialize_geometry(&empty) + } + + pub fn deserialize_geometry(&self, value: &Value) -> Result, OutputPluginError> { + match self { + IsochroneOutputFormat::Wkt => { + let wkt = value.as_str().ok_or_else(|| { + OutputPluginError::OutputPluginFailed(format!( + "expected WKT string for geometry deserialization, found: {value:?}" + )) + })?; + let g = Geometry::try_from_wkt_str(wkt).map_err(|e| OutputPluginError::OutputPluginFailed(format!("failure deserializing WKT geometry from output row due to: {e} - WKT string: \"{wkt}\"")))?; + Ok(g) + } + IsochroneOutputFormat::Wkb => { + let wkb_str = value.as_str().ok_or_else(|| { + OutputPluginError::OutputPluginFailed(format!( + "expected WKB string for geometry deserialization, found: {value:?}" + )) + })?; + // Decode hex string to bytes + let wkb_bytes = hex::decode(wkb_str).map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "failed to decode WKB hex string: {e} - WKB string: \"{wkb_str}\"" + )) + })?; + // Read geometry as f64, then convert to f32 + let geom_trait = wkb::reader::read_wkb(&wkb_bytes).map_err(|e| OutputPluginError::OutputPluginFailed(format!( + "failure deserializing WKB geometry from output row due to: {e} - WKB string: \"{wkb_str}\"" + )))?; + let geometry_f64 = geom_trait.to_geometry(); + let geometry_f32 = try_convert_f32(&geometry_f64)?; + Ok(geometry_f32) + } + IsochroneOutputFormat::GeoJson => { + let geojson_str = value.as_str().ok_or_else(|| { + OutputPluginError::OutputPluginFailed(format!( + "expected string for geometry deserialization, found: {value:?}" + )) + })?; + let geojson_obj = geojson_str.parse::().map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "failure parsing GeoJSON from geometry string due to: {e}, found: {value:?}" + )) + })?; + let geometry = geo_types::Geometry::::try_from(geojson_obj).map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "failure converting GeoJSON to Geometry due to: {e}" + )) + })?; + Ok(geometry) + } + } + } + + pub fn serialize_geometry( + &self, + geometry: &Geometry, + ) -> Result { + match self { + IsochroneOutputFormat::Wkt => Ok(geometry.wkt_string()), + IsochroneOutputFormat::Wkb => { + let mut out_bytes = vec![]; + let geom: Geometry = geometry.try_convert().map_err(|e| { + OutputPluginError::OutputPluginFailed(format!( + "unable to convert geometry from f32 to f64: {e}" + )) + })?; + let write_options = wkb::writer::WriteOptions { + endianness: wkb::Endianness::BigEndian, + }; + wkb::writer::write_geometry(&mut out_bytes, &geom, &write_options).map_err( + |e| { + OutputPluginError::OutputPluginFailed(format!( + "failed to write geometry as WKB: {e}" + )) + }, + )?; + + Ok(out_bytes + .iter() + .map(|b| format!("{b:02X?}")) + .collect::>() + .join("")) + } + IsochroneOutputFormat::GeoJson => { + let geometry = geojson::Geometry::from(geometry); + let feature = geojson::Feature { + bbox: None, + geometry: Some(geometry), + id: None, + properties: None, + foreign_members: None, + }; + let result = serde_json::to_value(feature)?; + Ok(result.to_string()) + } + } + } +} + +fn try_convert_f32(g: &Geometry) -> Result, OutputPluginError> { + let (min, max) = (f32::MIN as f64, f32::MAX as f64); + g.try_map_coords(|geo::Coord { x, y }| { + if x < min || max < x { + Err(OutputPluginError::OutputPluginFailed(format!( + "could not express x value '{x}' as f32, exceeds range of possible values [{min}, {max}]" + ))) + } else if y < min || max < y { + Err(OutputPluginError::OutputPluginFailed(format!( + "could not express y value '{y}' as f32, exceeds range of possible values [{min}, {max}]" + ))) + } else { + let x32 = x as f32; + let y32 = y as f32; + Ok(geo::Coord { x: x32, y: y32 }) + } + }) +} diff --git a/rust/bambam-core/src/model/output_plugin/isochrone/mod.rs b/rust/bambam-core/src/model/output_plugin/isochrone/mod.rs new file mode 100644 index 00000000..2bd3c59a --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/isochrone/mod.rs @@ -0,0 +1,3 @@ +mod isochrone_output_format; + +pub use isochrone_output_format::IsochroneOutputFormat; diff --git a/rust/bambam-core/src/model/output_plugin/mod.rs b/rust/bambam-core/src/model/output_plugin/mod.rs new file mode 100644 index 00000000..8e92795f --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/mod.rs @@ -0,0 +1,2 @@ +pub mod isochrone; +pub mod opportunity; diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/destination_opportunity.rs b/rust/bambam-core/src/model/output_plugin/opportunity/destination_opportunity.rs new file mode 100644 index 00000000..bf2273b7 --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/destination_opportunity.rs @@ -0,0 +1,11 @@ +use routee_compass_core::model::state::StateVariable; +use serde::{Deserialize, Serialize}; + +/// activity counts and vehicle state observed when reaching a destination +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DestinationOpportunity { + /// opportunity counts for this location + pub counts: Vec, + /// vehicle state when this location was reached + pub state: Vec, +} diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/mod.rs b/rust/bambam-core/src/model/output_plugin/opportunity/mod.rs new file mode 100644 index 00000000..36cc5d49 --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/mod.rs @@ -0,0 +1,11 @@ +mod destination_opportunity; +mod opportunity_format; +mod opportunity_orientation; +mod opportunity_record; +mod opportunity_row_id; + +pub use destination_opportunity::DestinationOpportunity; +pub use opportunity_format::OpportunityFormat; +pub use opportunity_orientation::OpportunityOrientation; +pub use opportunity_record::OpportunityRecord; +pub use opportunity_row_id::OpportunityRowId; diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_format.rs b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_format.rs new file mode 100644 index 00000000..520cac2d --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_format.rs @@ -0,0 +1,82 @@ +use crate::model::bambam_field; +use crate::model::output_plugin::opportunity::{DestinationOpportunity, OpportunityRowId}; +use routee_compass::plugin::output::OutputPluginError; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +/// Sets how opportunities are tagged to a response row as either aggregate or disaggregate. +#[derive(Deserialize, Serialize, Clone, Debug, Copy)] +#[serde(rename_all = "snake_case")] +pub enum OpportunityFormat { + /// write opportunities as a JSON object with keys as activity types, values + /// as activity counts summed across the entire scenario + Aggregate, + /// write opportunities as a JSON object with keys as destination id, values + /// as opportunity count objects + Disaggregate, +} + +impl std::fmt::Display for OpportunityFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let key = match self { + OpportunityFormat::Aggregate => bambam_field::OPP_FMT_AGGREGATE, + OpportunityFormat::Disaggregate => bambam_field::OPP_FMT_DISAGGREGATE, + }; + write!(f, "{key}") + } +} + +impl OpportunityFormat { + /// serializes the provided opportunities into JSON based on the chosen format. + /// + /// # Arguments + /// + /// * `opportunities` - the output of the [`super::opportunity_model::OpportunityModel`] + /// * `activity_types` - the names of each activity in each opportunity row + /// + /// # Returns + /// + /// A JSON object representing these opportunities + pub fn serialize_opportunities( + &self, + opportunities: &Vec<(OpportunityRowId, DestinationOpportunity)>, + activity_types: &Vec, + ) -> Result { + match self { + OpportunityFormat::Aggregate => { + // accumulate activity count totals + let mut acc: Vec = vec![0.0; activity_types.len()]; + for (_id, row) in opportunities { + for (idx, row_value) in row.counts.iter().enumerate() { + acc[idx] += *row_value; + } + } + // create output mapping, a Map + let mut result = serde_json::Map::new(); + for (cnt, act) in acc.iter().zip(activity_types) { + result.insert(act.to_owned(), json![cnt]); + } + Ok(result.into()) + } + OpportunityFormat::Disaggregate => { + // serialize all rows as a mapping from id to opportunity counts object + let mut result = serde_json::Map::new(); + for (id, row) in opportunities { + let mut row_obj = serde_json::Map::new(); + for (idx, row_value) in row.counts.iter().enumerate() { + let activity_type = + activity_types + .get(idx) + .cloned() + .ok_or_else(|| OutputPluginError::InternalError(format!( + "index {idx} invalid for opportunity vector {row_value:?}, should match cardinality of activity types dataset {activity_types:?}" + )))?; + row_obj.insert(activity_type, json!(row_value)); + } + result.insert(id.to_string(), row_obj.into()); + } + Ok(result.into()) + } + } + } +} diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_orientation.rs b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_orientation.rs new file mode 100644 index 00000000..7ebf8104 --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_orientation.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +/// An enumeration representing how activities are tagged to the graph. +#[derive(Deserialize, Serialize, Clone, Copy, Debug, Hash, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum OpportunityOrientation { + OriginVertexOriented, + DestinationVertexOriented, + EdgeOriented, +} + +impl std::fmt::Display for OpportunityOrientation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = serde_json::to_string(self) + .unwrap_or(String::from("")) + .replace('\"', ""); + write!(f, "{s}") + } +} diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_record.rs b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_record.rs new file mode 100644 index 00000000..e0e7ca70 --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_record.rs @@ -0,0 +1,89 @@ +use crate::model::{ + bambam_field, bambam_state, output_plugin::opportunity::OpportunityOrientation, TimeBin, +}; +use routee_compass::plugin::output::OutputPluginError; +use routee_compass_core::model::state::{StateModel, StateVariable}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uom::si::f64::Time; + +/// properties of accessing some activity type from a grid cell origin location. comes in two flavors: +/// +/// 1. Aggregate - zonal/isochrone access to a type of activity +/// 2. Disaggregate - access data for exactly one opportunity +/// +/// the properties of this opportunity access influence the modal intensities, modal coefficients, +/// and activity frequencies selected for computing an access metric. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum OpportunityRecord { + Aggregate { + activity_type: String, + geometry: geo::Geometry, + time_bin: TimeBin, + count: f64, + }, + Disaggregate { + id: String, + activity_type: String, + opportunity_orientation: OpportunityOrientation, + geometry: geo::Geometry, + state: Vec, + }, +} + +impl OpportunityRecord { + pub fn get_json_path(&self) -> Vec { + match self { + OpportunityRecord::Aggregate { time_bin, .. } => { + vec![bambam_field::TIME_BINS.to_string(), time_bin.key()] + } + OpportunityRecord::Disaggregate { id, .. } => { + vec![bambam_field::OPPORTUNITIES.to_string(), id.to_string()] + } + } + } + + pub fn get_time(&self, state_model: Arc) -> Result { + match self { + Self::Disaggregate { state, .. } => { + // time comes from the trip travel time taken to reach this point + state_model.get_time(state, bambam_state::TRIP_TIME) + .map_err(|e| OutputPluginError::OutputPluginFailed(format!("with disaggregate opportunity record, could not find trip time due to: {e}"))) + } + Self::Aggregate { time_bin, .. } => { + // time comes from the isochrone bin + Ok(time_bin.max_time()) + } + } + } + pub fn get_activity_type(&self) -> &str { + match self { + Self::Aggregate { activity_type, .. } => activity_type, + Self::Disaggregate { activity_type, .. } => activity_type, + } + } + + pub fn get_count(&self) -> f64 { + match self { + OpportunityRecord::Aggregate { count, .. } => *count, + OpportunityRecord::Disaggregate { .. } => 1.0, + } + } + + pub fn get_geometry(&self) -> &geo::Geometry { + match self { + Self::Aggregate { geometry, .. } => geometry, + Self::Disaggregate { geometry, .. } => geometry, + } + } + + pub fn get_opportunity_orientation(&self) -> Option<&OpportunityOrientation> { + match self { + Self::Aggregate { .. } => None, + Self::Disaggregate { + opportunity_orientation, + .. + } => Some(opportunity_orientation), + } + } +} diff --git a/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_row_id.rs b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_row_id.rs new file mode 100644 index 00000000..b55bdd3c --- /dev/null +++ b/rust/bambam-core/src/model/output_plugin/opportunity/opportunity_row_id.rs @@ -0,0 +1,144 @@ +use std::sync::Arc; + +use geo::{Centroid, Convert}; +use routee_compass::plugin::output::OutputPluginError; +use routee_compass_core::{ + algorithm::search::{SearchInstance, SearchTreeNode}, + model::{ + label::Label, + map::MapModel, + network::{EdgeId, EdgeListId, Graph}, + }, +}; +use rstar::{RTreeObject, AABB}; +use serde::Serialize; +use wkt::ToWkt; + +use crate::model::output_plugin::opportunity::opportunity_orientation::OpportunityOrientation; + +// identifier in the graph tagging where an opportunity was found +#[derive(Serialize, Clone, PartialEq, Eq, Hash, Debug)] +pub enum OpportunityRowId { + OriginVertex(Label), + DestinationVertex(Label), + Edge(EdgeListId, EdgeId), +} + +impl std::fmt::Display for OpportunityRowId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + OpportunityRowId::OriginVertex(label) => label.to_string(), + OpportunityRowId::DestinationVertex(label) => label.to_string(), + OpportunityRowId::Edge(edge_list_id, edge_id) => format!("{edge_list_id}-{edge_id}"), + }; + write!(f, "{s}") + } +} + +impl OpportunityRowId { + /// create a new opportunity vector identifier based on the table orientation which denotes where opportunities are stored + pub fn new( + branch_label: &Label, + branch: &SearchTreeNode, + format: &OpportunityOrientation, + ) -> Result { + use OpportunityOrientation as O; + match format { + // stored at the origin of the edge, corresponding with the branch origin id + O::OriginVertexOriented => Ok(Self::OriginVertex(branch_label.clone())), + // stored at the destination of the edge at the branch's terminal vertex id + O::DestinationVertexOriented => Ok(Self::DestinationVertex(branch.label().clone())), + // stored on the edge itself + O::EdgeOriented => { + match branch.incoming_edge() { + None => Err(OutputPluginError::InternalError(String::from("while building EdgeOriented OpportunityRowId, was passed tree root, which has no corresponding edge"))), + Some(et) => Ok(Self::Edge(et.edge_list_id, et.edge_id)), + } + } + } + } + + /// helper to get the POINT geometry associated with this index, if defined + pub fn get_vertex_point( + &self, + graph: Arc, + ) -> Result, OutputPluginError> { + let vertex_id = match self { + OpportunityRowId::OriginVertex(label) => Ok(label.vertex_id()), + OpportunityRowId::DestinationVertex(label) => Ok(label.vertex_id()), + OpportunityRowId::Edge(..) => Err(OutputPluginError::InternalError(String::from( + "cannot get vertex point for edge", + ))), + }?; + + let vertex = graph.get_vertex(vertex_id).map_err(|_e| { + OutputPluginError::OutputPluginFailed(format!("unknown vertex id '{vertex_id}'")) + })?; + let point = geo::Point::new(vertex.x(), vertex.y()); + Ok(point) + } + + /// helper to get the LINESTRING geometry associated with this index, if defined + pub fn get_edge_linestring( + &self, + map_model: Arc, + ) -> Result, OutputPluginError> { + let (edge_list_id, edge_id) = match self { + OpportunityRowId::Edge(edge_list_id, edge_id) => Ok((edge_list_id, edge_id)), + _ => Err(OutputPluginError::InternalError(String::from( + "cannot get edge linestring for vertex", + ))), + }?; + map_model + .get_linestring(edge_list_id, edge_id) + .cloned() + .map_err(|_e| { + OutputPluginError::OutputPluginFailed(format!("unknown edge id '{edge_id}'")) + }) + } + + pub fn get_envelope_f64( + &self, + si: &SearchInstance, + ) -> Result, OutputPluginError> { + match self { + OpportunityRowId::OriginVertex(_) => { + let point = self.get_vertex_point(si.graph.clone())?.convert(); + Ok(point.envelope()) + } + OpportunityRowId::DestinationVertex(_) => { + let point = self.get_vertex_point(si.graph.clone())?.convert(); + Ok(point.envelope()) + } + OpportunityRowId::Edge(..) => { + let linestring = self.get_edge_linestring(si.map_model.clone())?.convert(); + Ok(linestring.envelope()) + } + } + } + + pub fn get_centroid_f64(&self, si: &SearchInstance) -> Result { + match self { + OpportunityRowId::OriginVertex(_) => { + let point = self.get_vertex_point(si.graph.clone())?.convert(); + let centroid = point.centroid(); + Ok(centroid) + } + OpportunityRowId::DestinationVertex(_) => { + let point = self.get_vertex_point(si.graph.clone())?.convert(); + let centroid = point.centroid(); + Ok(centroid) + } + OpportunityRowId::Edge(..) => { + let linestring = self.get_edge_linestring(si.map_model.clone())?.convert(); + let centroid = linestring.centroid().ok_or_else(|| { + OutputPluginError::OutputPluginFailed(format!( + "could not get centroid of LINESTRING {}", + linestring.to_wkt() + )) + })?; + Ok(centroid) + } + } + } +} diff --git a/rust/bambam-core/src/model/state/fieldname.rs b/rust/bambam-core/src/model/state/fieldname.rs new file mode 100644 index 00000000..b3b09e43 --- /dev/null +++ b/rust/bambam-core/src/model/state/fieldname.rs @@ -0,0 +1,51 @@ +//! field names for state variables related to multimodal routing. accumulators +//! for trip legs and mode-specific metrics need to be stored in fields with +//! normalized fieldname structures that are programmatically generated. + +use crate::model::state::LegIdx; +pub use routee_compass_core::model::traversal::default::fieldname::*; + +/// the id of the active leg. zero if no leg is active. 1+ are leg identifiers. +pub const ACTIVE_LEG: &str = "active_leg"; + +/// the state variable name containing the mode for a given leg id +pub fn leg_mode_fieldname(leg_idx: LegIdx) -> String { + leg_fieldname(leg_idx, "mode") +} + +/// the state variable name containing the distance for a given leg id +pub fn leg_distance_fieldname(leg_idx: LegIdx) -> String { + leg_fieldname(leg_idx, "distance") +} + +/// the state variable name containing the time for a given leg id +pub fn leg_time_fieldname(leg_idx: LegIdx) -> String { + leg_fieldname(leg_idx, "time") +} + +/// the state variable name containing the route id for a given leg id +pub fn leg_route_id_fieldname(leg_idx: LegIdx) -> String { + leg_fieldname(leg_idx, "route_id") +} + +/// the state variable name containing the distance for a given mode +pub fn mode_distance_fieldname(mode: &str) -> String { + mode_fieldname(mode, "distance") +} + +/// the state variable name containing the time for a given mode +pub fn mode_time_fieldname(mode: &str) -> String { + mode_fieldname(mode, "time") +} + +/// helper function for creating normalized and enumerated fieldnames +/// for fields associated with a trip leg. +fn leg_fieldname(leg_idx: LegIdx, field: &str) -> String { + format!("leg_{leg_idx}_{field}") +} + +/// helper function for creating normalized fieldnames for fields +/// accumulating metrics for a given travel mode. +fn mode_fieldname(mode: &str, field: &str) -> String { + format!("mode_{mode}_{field}") +} diff --git a/rust/bambam-core/src/model/state/mod.rs b/rust/bambam-core/src/model/state/mod.rs new file mode 100644 index 00000000..1129f8d1 --- /dev/null +++ b/rust/bambam-core/src/model/state/mod.rs @@ -0,0 +1,13 @@ +pub mod fieldname; +mod multimodal_mapping; +pub mod multimodal_state_ops; +pub mod variable; + +pub use multimodal_mapping::MultimodalMapping; +pub use multimodal_mapping::MultimodalStateMapping; +/// trip legs are enumerated starting from 0 to support zero-based indexing arithmetic. +pub type LegIdx = u64; + +/// value for entries in a [`use routee_compass_core::model::label::Label`] denoting +/// no modes are assigned for a given trip leg. +pub const EMPTY_TRIP_LABEL: usize = usize::MAX; diff --git a/rust/bambam-core/src/model/state/multimodal_mapping.rs b/rust/bambam-core/src/model/state/multimodal_mapping.rs new file mode 100644 index 00000000..c10134cf --- /dev/null +++ b/rust/bambam-core/src/model/state/multimodal_mapping.rs @@ -0,0 +1,169 @@ +use routee_compass_core::{ + model::state::StateModelError, + util::fs::{read_decoders, read_utils}, +}; +use std::{collections::HashMap, fmt::Debug, path::Path}; + +/// stores the bijection from categorical name to an enumeration label compatible +/// with Compass LabelModels and Custom StateModels. +/// +/// ## Types +/// T: the type `T` is some hashable, categorical type. consider String as a starting point. +/// +/// U: the type `U` is some hashable, numeric typw which can be built From so that it +/// can be used to index a Vec. +/// +#[derive(Clone, Debug)] +pub struct MultimodalMapping { + cat_to_label: HashMap, + label_to_cat: Vec, +} + +/// a common type of multimodal mapping which maps strings to i64 values. +/// categories begin from zero. negative values denote an empty class label (None case). +pub type MultimodalStateMapping = MultimodalMapping; + +/// A trait for types that can be used as categorical identifiers +#[allow(dead_code)] +trait Categorical: Eq + std::hash::Hash + Clone + Debug {} + +/// A trait for types that can be used as indices in the multimodal mapping +#[allow(dead_code)] +trait IndexType: + Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + PartialOrd + Debug +{ +} + +// Blanket implementations for common types +impl Categorical for T where T: Eq + std::hash::Hash + Clone + Debug {} +impl IndexType for U where + U: Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + PartialOrd + Debug +{ +} + +impl MultimodalStateMapping { + pub fn from_enumerated_category_file(filepath: &Path) -> Result { + let contents = read_utils::read_raw_file(filepath, read_decoders::string, None, None) + .map_err(|e| { + StateModelError::BuildError(format!( + "failure reading enumerated category mapping from {}: {e}", + filepath.to_string_lossy() + )) + })?; + MultimodalMapping::new(&contents) + } +} + +impl MultimodalMapping +where + T: Eq + std::hash::Hash + Clone + Debug, + U: Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + PartialOrd + Debug, +{ + /// create an empty mapping + pub fn empty() -> Self { + MultimodalMapping { + cat_to_label: HashMap::new(), + label_to_cat: Vec::new(), + } + } + + /// create a new mapping from a list of categoricals. for each categorical value in categoricals, + /// it will be assigned a label integer using the categorical's row index. + pub fn new(categoricals: &[T]) -> Result { + let label_to_cat = categoricals.to_vec(); + let cat_to_label = label_to_cat + .iter() + .enumerate() + .map(|(idx, t)| { + let u = try_into_u(idx).map_err(|e| { + StateModelError::BuildError(format!("for mapping value {t:?}, {e}")) + })?; + Ok((t.clone(), u)) + }) + .collect::, StateModelError>>()?; + Ok(Self { + cat_to_label, + label_to_cat, + }) + } + + /// get the list (in enumeration order) of categories + pub fn get_categories(&self) -> &[T] { + &self.label_to_cat + } + + /// count the number of mapped categories + pub fn n_categories(&self) -> usize { + self.label_to_cat.len() + } + + /// append another categorical to the mapping, returning the new categorical's label id. + /// if the categorical is already stored in the mapping, we return the existing label and + /// no insert occurs. + pub fn insert(&mut self, categorical: T) -> Result { + if let Some(&label) = self.cat_to_label.get(&categorical) { + return Ok(label); + } + + let next_label = try_into_u(self.label_to_cat.len())?; + self.cat_to_label.insert(categorical.clone(), next_label); + self.label_to_cat.push(categorical); + Ok(next_label) + } + + /// perform a categorical->label lookup. + pub fn get_label(&self, categorical: &Q) -> Option<&U> + where + T: std::borrow::Borrow, + Q: std::hash::Hash + Eq + ?Sized, + { + self.cat_to_label.get(categorical) + } + + /// perform a label->categorical lookup. + pub fn get_categorical(&self, label: U) -> Result, StateModelError> { + if is_empty(label)? { + Ok(None) + } else { + let idx: usize = try_into_usize(label)?; + let result = self.label_to_cat.get(idx); + Ok(result) + } + } +} + +fn is_empty(u: U) -> Result +where + U: Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + PartialOrd + Debug, +{ + let zero = U::try_from(0).map_err(|_| { + StateModelError::BuildError("could not create zero value for type".to_string()) + })?; + Ok(u < zero) +} + +/// helper function to convert a U into a usize with error message result. +/// handles the case where u is negative and treats that as an empty value. +fn try_into_usize(u: U) -> Result +where + U: Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + Debug, +{ + let as_usize = u.try_into().map_err(|_e| { + StateModelError::BuildError(format!( + "could not convert Index {u:?} to a usize type, should implement TryInto" + )) + })?; + Ok(as_usize) +} + +/// helper function to convert a usize into a U with error message result +fn try_into_u(idx: usize) -> Result +where + U: Eq + std::hash::Hash + Clone + Copy + TryFrom + TryInto + Debug, +{ + idx.try_into().map_err(|_e| { + StateModelError::BuildError(format!( + "could not convert index {idx} to a Index type, should implement TryFrom" + )) + }) +} diff --git a/rust/bambam-core/src/model/state/multimodal_state_ops.rs b/rust/bambam-core/src/model/state/multimodal_state_ops.rs new file mode 100644 index 00000000..d2dd57cf --- /dev/null +++ b/rust/bambam-core/src/model/state/multimodal_state_ops.rs @@ -0,0 +1,296 @@ +use crate::model::state::{LegIdx, MultimodalStateMapping}; +use routee_compass_core::model::state::{StateModel, StateModelError, StateVariable}; +use serde_json::json; +use uom::si::f64::{Length, Time}; + +use super::fieldname; + +/// inspect the current active leg for a trip +pub fn get_active_leg_idx( + state: &[StateVariable], + state_model: &StateModel, +) -> Result, StateModelError> { + let leg_i64 = state_model.get_custom_i64(state, fieldname::ACTIVE_LEG)?; + if leg_i64 < 0 { + Ok(None) + } else { + let leg_u64 = leg_i64.try_into().map_err(|_e| { + StateModelError::RuntimeError(format!( + "internal error: while getting active trip leg, unable to parse {leg_i64} as a u64" + )) + })?; + Ok(Some(leg_u64)) + } +} + +/// inspect the current active leg mode for a trip. if the trip +/// has no leg, returns None. +pub fn get_active_leg_mode<'a>( + state: &[StateVariable], + state_model: &StateModel, + max_trip_legs: LegIdx, + mode_to_state: &'a MultimodalStateMapping, +) -> Result, StateModelError> { + match get_active_leg_idx(state, state_model)? { + None => Ok(None), + Some(leg_idx) => { + let mode = + get_existing_leg_mode(state, leg_idx, state_model, max_trip_legs, mode_to_state)?; + Ok(Some(mode)) + } + } +} + +/// use the active leg index to count the number of trip legs in this state vector +pub fn get_n_legs( + state: &[StateVariable], + state_model: &StateModel, +) -> Result { + match get_active_leg_idx(state, state_model)? { + None => Ok(0), + Some(leg_idx) => { + let count: usize = (leg_idx + 1).try_into().map_err(|_e| { + StateModelError::RuntimeError(format!( + "internal error: unable to convert leg index {leg_idx} from u64 into usize" + )) + })?; + Ok(count) + } + } +} + +/// report if any trip data has been recorded for the given trip leg. +/// this uses the fact that any trip leg must have a leg mode, and leg modes +/// are stored with non-negative integer values, negative denotes "empty". +/// see [`super::state_variable`] for the leg mode variable configuration. +pub fn contains_leg( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, +) -> Result { + let name = fieldname::leg_mode_fieldname(leg_idx); + let label = state_model.get_custom_i64(state, &name)?; + Ok(label >= 0) +} + +/// get the travel mode for a leg. +pub fn get_leg_mode_label( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, + max_trip_legs: LegIdx, +) -> Result, StateModelError> { + validate_leg_idx(leg_idx, max_trip_legs)?; + let name = fieldname::leg_mode_fieldname(leg_idx); + let label = state_model.get_custom_i64(state, &name)?; + if label < 0 { + Ok(None) + } else { + Ok(Some(label)) + } +} + +/// get the travel mode for a leg. assumed that the leg mode exists, +/// if the mode is not set, it is an error. +pub fn get_existing_leg_mode<'a>( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, + max_trip_legs: LegIdx, + mode_to_state: &'a MultimodalStateMapping, +) -> Result<&'a str, StateModelError> { + let label_opt = get_leg_mode_label(state, leg_idx, state_model, max_trip_legs)?; + match label_opt { + None => Err(StateModelError::RuntimeError(format!( + "Internal Error: get_leg_mode called on leg idx {leg_idx} but mode label is not set" + ))), + Some(label) => mode_to_state + .get_categorical(label)? + .ok_or_else(|| { + StateModelError::RuntimeError(format!( + "internal error, leg {leg_idx} has invalid mode label {label}" + )) + }) + .map(|s| s.as_str()), + } +} + +pub fn get_leg_distance( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, +) -> Result { + let name = fieldname::leg_distance_fieldname(leg_idx); + state_model.get_distance(state, &name) +} + +pub fn get_leg_time( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, +) -> Result { + let name = fieldname::leg_time_fieldname(leg_idx); + state_model.get_time(state, &name) +} + +pub fn get_leg_route_id<'a>( + state: &[StateVariable], + leg_idx: LegIdx, + state_model: &StateModel, + route_id_mapping: &'a MultimodalStateMapping, +) -> Result, StateModelError> { + let name = fieldname::leg_route_id_fieldname(leg_idx); + let route_id_label = state_model.get_custom_i64(state, &name)?; + let route_id = route_id_mapping.get_categorical(route_id_label)?; + Ok(route_id) +} + +pub fn get_mode_distance( + state: &[StateVariable], + mode: &str, + state_model: &StateModel, +) -> Result { + let name = fieldname::mode_distance_fieldname(mode); + state_model.get_distance(state, &name) +} + +pub fn get_mode_time( + state: &[StateVariable], + mode: &str, + state_model: &StateModel, +) -> Result { + let name = fieldname::mode_time_fieldname(mode); + state_model.get_time(state, &name) +} + +/// retrieves the sequence of mode labels stored on this state. stops when an unset +/// mode label is encountered. +pub fn get_mode_label_sequence( + state: &[StateVariable], + state_model: &StateModel, + max_trip_legs: LegIdx, +) -> Result, StateModelError> { + let mut labels: Vec = vec![]; + + for leg_idx in 0..max_trip_legs { + let mode_label_opt = get_leg_mode_label(state, leg_idx, state_model, max_trip_legs)?; + match mode_label_opt { + None => break, + Some(mode_label) => { + labels.push(mode_label); + } + } + } + + Ok(labels) +} + +/// retrieves the sequence of modes stored on this state. stops when an unset +/// mode label is encountered. +pub fn get_mode_sequence( + state: &[StateVariable], + state_model: &StateModel, + max_trip_legs: LegIdx, + mode_to_state: &MultimodalStateMapping, +) -> Result, StateModelError> { + let mut modes: Vec = vec![]; + let mut leg_idx = 0; + while contains_leg(state, leg_idx, state_model)? { + let mode = + get_existing_leg_mode(state, leg_idx, state_model, max_trip_legs, mode_to_state)?; + modes.push(mode.to_string()); + leg_idx += 1; + } + Ok(modes) +} + +/// increments the value at [`fieldname::ACTIVE_LEG`]. +/// when ACTIVE_LEG is negative (no active leg), it becomes zero. +/// when it is a number in [0, max_legs-1), it is incremented by one. +/// returns the new index value. +pub fn increment_active_leg_idx( + state: &mut [StateVariable], + state_model: &StateModel, + max_trip_legs: LegIdx, +) -> Result { + // get the index of the next leg + let next_leg_idx_u64 = match get_active_leg_idx(state, state_model)? { + Some(leg_idx) => { + let next = leg_idx + 1; + validate_leg_idx(next, max_trip_legs)?; + next + } + None => 0, + }; + // as an i64, to match the storage format + let next_leg_idx: i64 = next_leg_idx_u64.try_into().map_err(|_e| { + StateModelError::RuntimeError(format!( + "internal error: while getting active trip leg, unable to parse {next_leg_idx_u64} as a i64" + )) + })?; + + // increment the value in the state vector + state_model.set_custom_i64(state, fieldname::ACTIVE_LEG, &next_leg_idx)?; + Ok(next_leg_idx_u64) +} + +/// sets the mode value for the given leg. performs mapping from Mode -> i64 which is +/// the storage type for Mode in the state vector. +pub fn set_leg_mode( + state: &mut [StateVariable], + leg_idx: LegIdx, + mode: &str, + state_model: &StateModel, + mode_to_state: &MultimodalStateMapping, +) -> Result<(), StateModelError> { + let mode_label = mode_to_state.get_label(mode).ok_or_else(|| { + StateModelError::RuntimeError(format!("mode mapping has no entry for '{mode}' mode")) + })?; + let name = fieldname::leg_mode_fieldname(leg_idx); + state_model.set_custom_i64(state, &name, mode_label) +} + +/// sets the mode value for the given leg. performs mapping from Mode -> i64 which is +/// the storage type for Mode in the state vector. +pub fn set_leg_route_id( + state: &mut [StateVariable], + leg_idx: LegIdx, + route_id: &str, + state_model: &StateModel, + route_id_to_state: &MultimodalStateMapping, +) -> Result<(), StateModelError> { + let route_id_label = route_id_to_state.get_label(route_id).ok_or_else(|| { + StateModelError::RuntimeError(format!( + "route_id mapping has no entry for '{route_id}' route id" + )) + })?; + let name = fieldname::leg_route_id_fieldname(leg_idx); + state_model.set_custom_i64(state, &name, route_id_label) +} + +/// validates leg_idx values, which must be in range [0, max_trip_legs) +pub fn validate_leg_idx(leg_idx: LegIdx, max_trip_legs: LegIdx) -> Result<(), StateModelError> { + if leg_idx >= max_trip_legs { + Err(StateModelError::RuntimeError(format!( + "invalid leg id {leg_idx} >= max leg id {max_trip_legs}" + ))) + } else { + Ok(()) + } +} + +/// helper function for creating a descriptive error when attempting to apply +/// the multimodal traversal model on a state that has not activated it's first trip leg. +pub fn error_inactive_state_traversal( + state: &[StateVariable], + state_model: &StateModel, +) -> StateModelError { + let next_json = state_model.serialize_state(state, false).unwrap_or_else( + |e| json!({"message": "unable to serialize state!", "error": format!("{e}")}), + ); + let next_string = serde_json::to_string_pretty(&next_json) + .unwrap_or_else(|_e| String::from("")); + StateModelError::RuntimeError(format!( + "attempting multimodal traversal with state that has no active leg: {next_string}" + )) +} diff --git a/rust/bambam-core/src/model/state/variable.rs b/rust/bambam-core/src/model/state/variable.rs new file mode 100644 index 00000000..ffcd4b48 --- /dev/null +++ b/rust/bambam-core/src/model/state/variable.rs @@ -0,0 +1,81 @@ +//! constructors for [`StateVariableConfig`] instances in multimodal routing. +use crate::model::state::{fieldname, LegIdx}; +use routee_compass_core::model::{ + state::{CustomVariableConfig, InputFeature, StateVariableConfig}, + unit::{DistanceUnit, TimeUnit}, +}; +use uom::{ + si::f64::{Length, Time}, + ConstZero, +}; + +/// config value representing an empty LegIdx, Mode, or RouteId. +pub const EMPTY: CustomVariableConfig = CustomVariableConfig::SignedInteger { initial: -1 }; + +pub fn active_leg_input_feature() -> InputFeature { + InputFeature::Custom { + name: "active_leg".to_string(), + unit: "signed_integer".to_string(), + } +} + +pub fn active_leg_variable_config() -> StateVariableConfig { + StateVariableConfig::Custom { + custom_type: "ActiveLeg".to_string(), + value: EMPTY, + accumulator: true, + } +} + +pub fn leg_mode_input_feature(leg_idx: LegIdx) -> InputFeature { + InputFeature::Custom { + name: fieldname::leg_mode_fieldname(leg_idx), + unit: "signed_integer".to_string(), + } +} + +/// creates configuration for mode state variables +pub fn leg_mode_variable_config() -> StateVariableConfig { + StateVariableConfig::Custom { + custom_type: "Mode".to_string(), + value: EMPTY, + accumulator: true, + } +} + +/// creates configuration for distance state variables +pub fn multimodal_distance_variable_config( + output_unit: Option, +) -> StateVariableConfig { + StateVariableConfig::Distance { + initial: Length::ZERO, + accumulator: true, + output_unit, + } +} + +/// creates configuration for time state variables +pub fn multimodal_time_variable_config(output_unit: Option) -> StateVariableConfig { + StateVariableConfig::Time { + initial: Time::ZERO, + accumulator: true, + output_unit, + } +} + +/// creates configuration for route_id state variables +pub fn route_id_input_feature() -> InputFeature { + InputFeature::Custom { + name: "route_id".to_string(), + unit: "signed_integer".to_string(), + } +} + +/// creates configuration for route_id state variables +pub fn route_id_variable_config() -> StateVariableConfig { + StateVariableConfig::Custom { + custom_type: "RouteId".to_string(), + value: EMPTY, + accumulator: true, + } +} diff --git a/rust/bambam-core/src/model/time_bin.rs b/rust/bambam-core/src/model/time_bin.rs new file mode 100644 index 00000000..b59f68d5 --- /dev/null +++ b/rust/bambam-core/src/model/time_bin.rs @@ -0,0 +1,47 @@ +use super::bambam_ops; +use routee_compass_core::model::state::{StateModel, StateModelError, StateVariable}; +use serde::{Deserialize, Serialize}; +use uom::si::f64::Time; + +/// a configuration describing the time bounds for a "ring" of an isochrone. +/// time values are in minutes. +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct TimeBin { + pub min_time: u64, + pub max_time: u64, +} + +impl TimeBin { + // construct a TimeBin. if no min is provided, use "0" minutes + pub fn new(min: Option, max: u64) -> TimeBin { + TimeBin { + min_time: min.unwrap_or_default(), + max_time: max, + } + } + + pub fn key(&self) -> String { + format!("{}", self.max_time) + } + + /// grab the time bin's lower bound as a Time value in a specified time unit + pub fn min_time(&self) -> Time { + Time::new::(self.min_time as f64) + } + + /// grab the time bin's upper bound as a Time value in a specified time unit + pub fn max_time(&self) -> Time { + Time::new::(self.max_time as f64) + } + + pub fn state_time_within_bin( + &self, + state: &[StateVariable], + state_model: &StateModel, + ) -> Result { + let time = bambam_ops::get_reachability_time(state, state_model)?; + let minutes = time.get::() as u64; + let within_bin = self.min_time <= minutes && minutes < self.max_time; + Ok(within_bin) + } +} diff --git a/rust/bambam-core/src/util/date_deserialization_ops.rs b/rust/bambam-core/src/util/date_deserialization_ops.rs new file mode 100644 index 00000000..4f605a97 --- /dev/null +++ b/rust/bambam-core/src/util/date_deserialization_ops.rs @@ -0,0 +1,29 @@ +use chrono::{NaiveDate, NaiveDateTime, ParseResult}; +use serde::de::Error; +use serde::Deserialize; +use serde::Deserializer; + +pub const APP_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S"; +pub const APP_DATE_FORMAT: &str = "%Y-%m-%d"; + +pub fn naive_date_to_str(date_str: &str) -> ParseResult { + chrono::NaiveDate::parse_from_str(date_str, APP_DATE_FORMAT) +} + +pub fn deserialize_naive_datetime<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let date_str: String = String::deserialize(deserializer)?; + chrono::NaiveDateTime::parse_from_str(&date_str, APP_DATETIME_FORMAT) + .map_err(|e| D::Error::custom(format!("Invalid datetime format: {e}"))) +} + +pub fn deserialize_naive_date<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let date_str: String = String::deserialize(deserializer)?; + naive_date_to_str(&date_str) + .map_err(|e| D::Error::custom(format!("Invalid datetime format: {e}"))) +} diff --git a/rust/bambam-core/src/util/geo_utils.rs b/rust/bambam-core/src/util/geo_utils.rs new file mode 100644 index 00000000..3f7aaf10 --- /dev/null +++ b/rust/bambam-core/src/util/geo_utils.rs @@ -0,0 +1,32 @@ +use geo::Centroid; +use geo::{Geometry, Point}; +use rstar::RTreeObject; +use rstar::AABB; + +/// creates an envelope from a geometry using assumptions that +/// - points, linestrings, polygons can have their bboxes be their envelopes +/// - other geometry types can use their centroids +/// +/// since a centroid may not exist (for example, empty geometries), the result may be None +/// +/// # Arguments +/// +/// * `geometry` - value to create an envelope from +/// +/// # Returns +/// +/// * an envelope if possible, otherwise None +pub fn get_centroid_as_envelope(geometry: &Geometry) -> Option>> { + match geometry { + Geometry::Point(g) => Some(g.envelope()), + Geometry::Line(g) => Some(g.envelope()), + Geometry::LineString(g) => Some(g.envelope()), + Geometry::Polygon(g) => Some(g.envelope()), + Geometry::MultiPoint(g) => g.centroid().map(AABB::from_point), + Geometry::MultiLineString(g) => g.centroid().map(AABB::from_point), + Geometry::MultiPolygon(g) => g.centroid().map(AABB::from_point), + Geometry::GeometryCollection(g) => g.centroid().map(AABB::from_point), + Geometry::Rect(g) => Some(AABB::from_point(g.centroid())), + Geometry::Triangle(g) => Some(AABB::from_point(g.centroid())), + } +} diff --git a/rust/bambam-core/src/util/mod.rs b/rust/bambam-core/src/util/mod.rs new file mode 100644 index 00000000..c40b02d4 --- /dev/null +++ b/rust/bambam-core/src/util/mod.rs @@ -0,0 +1,3 @@ +pub mod date_deserialization_ops; +pub mod geo_utils; +pub mod polygonal_rtree; diff --git a/rust/bambam-core/src/util/polygonal_rtree.rs b/rust/bambam-core/src/util/polygonal_rtree.rs new file mode 100644 index 00000000..bb373d46 --- /dev/null +++ b/rust/bambam-core/src/util/polygonal_rtree.rs @@ -0,0 +1,127 @@ +use geo::{Area, BooleanOps, BoundingRect, Geometry, Intersects, Polygon}; +use rstar::{primitives::Rectangle, RTree, RTreeObject, AABB}; +use wkt::ToWkt; + +pub struct Node { + pub geometry: Geometry, + pub rectangle: Rectangle<(f64, f64)>, + pub data: D, +} + +impl Node { + pub fn new(geometry: Geometry, data: D) -> Result, String> { + let rectangle = rect_from_geometry(&geometry)?; + let result = Node { + geometry, + rectangle, + data, + }; + Ok(result) + } +} + +impl RTreeObject for Node { + type Envelope = AABB<(f64, f64)>; + + fn envelope(&self) -> Self::Envelope { + self.rectangle.envelope() + } +} + +pub struct PolygonalRTree(RTree>); + +impl PolygonalRTree { + pub fn new(data: Vec<(Geometry, D)>) -> Result, String> { + let nodes = data + .into_iter() + .map(|(g, d)| Node::new(g, d)) + .collect::, _>>()?; + let tree = RTree::bulk_load(nodes); + Ok(PolygonalRTree(tree)) + } + + /// tests for intersection with polygonal data in this tree. this involves two steps: + /// 1. finding rtree rectangle envelopes that intersect the incoming geometry + /// 2. testing intersection for each discovered geometry bounded by it's rectangle + pub fn intersection<'a>( + &'a self, + g: &'a Geometry, + ) -> Result> + 'a>, String> { + let query = rect_from_geometry(g)?; + let iter = self + .0 + .locate_in_envelope_intersecting(&query.envelope()) + .filter(|node| node.geometry.intersects(g)); + Ok(Box::new(iter)) + } + + pub fn intersection_with_overlap_area<'a>( + &'a self, + query: &'a Geometry, + ) -> Result, f64)>, String> { + // get all polygons in the query geometry + let query_polygons: Vec = match query { + Geometry::Polygon(p) => Ok(vec![p.clone()]), + Geometry::MultiPolygon(mp) => Ok(mp.0.clone()), + // Geometry::GeometryCollection(geometry_collection) => todo!(), + _ => Err(String::from( + "areal proportion query must be performed on polygonal data", + )), + }?; + + // compute the overlap area for each query polygon for each geometry + // found to intersect the query in the rtree + let result = self + .intersection(query)? + .map(|node| { + let overlap_areas = query_polygons + .iter() + .map(|p| { + let area = overlap_area(p, &node.geometry)?; + Ok(area) + }) + .collect::, String>>()?; + let overlap_area: f64 = overlap_areas.into_iter().sum(); + Ok((node, overlap_area)) + }) + .collect::, f64)>, String>>()?; + + Ok(result) + } +} + +/// helper function to create a rectangular rtree envelope for a given geometry +fn rect_from_geometry(g: &Geometry) -> Result, String> { + let bbox_vec = g.bounding_rect().ok_or_else(|| { + format!( + "internal error: cannot get bounds of geometry: '{}'", + g.to_wkt() + ) + })?; + + let envelope = Rectangle::from_corners(bbox_vec.min().x_y(), bbox_vec.max().x_y()); + Ok(envelope) +} + +/// helper for computing the overlap area, which is the area that two geometries have in common +/// (the intersection). +fn overlap_area(query: &Polygon, overlapping: &Geometry) -> Result { + match overlapping { + Geometry::Polygon(overlap_p) => { + let overlap = query.intersection(overlap_p); + let overlap_area: f64 = overlap.iter().map(|p| p.unsigned_area()).sum(); + Ok(overlap_area) + } + Geometry::MultiPolygon(mp) => { + let overlaps = mp + .iter() + .map(|overlapping_p| { + overlap_area(query, &geo::Geometry::Polygon(overlapping_p.clone())) + }) + .collect::, String>>()?; + let overlap_area: f64 = overlaps.into_iter().sum(); + Ok(overlap_area) + } + _ => Err(String::from("polygonal rtree node must be polygonal!")), + } +} diff --git a/rust/bambam-gtfs/src/model/mod.rs b/rust/bambam-gtfs/src/model/mod.rs new file mode 100644 index 00000000..f8e56a41 --- /dev/null +++ b/rust/bambam-gtfs/src/model/mod.rs @@ -0,0 +1 @@ +pub mod traversal; diff --git a/rust/bambam-gtfs/src/model/traversal/mod.rs b/rust/bambam-gtfs/src/model/traversal/mod.rs new file mode 100644 index 00000000..4fb0cb9c --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/mod.rs @@ -0,0 +1 @@ +pub mod transit; diff --git a/rust/bambam-gtfs/src/model/traversal/transit/builder.rs b/rust/bambam-gtfs/src/model/traversal/transit/builder.rs new file mode 100644 index 00000000..fca65e7f --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/builder.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use crate::model::traversal::transit::{ + config::TransitTraversalConfig, engine::TransitTraversalEngine, + service::TransitTraversalService, +}; +use routee_compass_core::model::traversal::{ + TraversalModelBuilder, TraversalModelError, TraversalModelService, +}; + +pub struct TransitTraversalBuilder {} + +impl TraversalModelBuilder for TransitTraversalBuilder { + fn build( + &self, + parameters: &serde_json::Value, + ) -> Result, TraversalModelError> { + let config: TransitTraversalConfig = + serde_json::from_value(parameters.clone()).map_err(|e| { + TraversalModelError::BuildError(format!( + "failed to read transit_traversal configuration: {e}" + )) + })?; + + let engine = TransitTraversalEngine::try_from(config)?; + let service = TransitTraversalService::new(Arc::new(engine)); + + Ok(Arc::new(service)) + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/config.rs b/rust/bambam-gtfs/src/model/traversal/transit/config.rs new file mode 100644 index 00000000..9af3bb32 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/config.rs @@ -0,0 +1,21 @@ +// Questions +// - Should the engine create the edges in compass? No +// - If we are already in the same route, should we make transit_boarding_time 0 but still the travel time = dst_arrival - current_time +// - If Schedules = Box<[Schedule]>, how do we access the correct schedule if I have an edge_id? edge_id is usize + +use serde::{Deserialize, Serialize}; + +use crate::model::traversal::transit::schedule_loading_policy::ScheduleLoadingPolicy; + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransitTraversalConfig { + /// edges-schedules file path from gtfs preprocessing + pub edges_schedules_input_file: String, + /// metadata file path from gtfs preprocessing + pub gtfs_metadata_input_file: String, + /// policy by which to prune departures when reading schedules + pub schedule_loading_policy: ScheduleLoadingPolicy, + /// if provided, overrides the metadata entry for fully-qualified + /// route ids, in the case of running multiple transit models simultaneously. + pub route_ids_input_file: Option, +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/engine.rs b/rust/bambam-gtfs/src/model/traversal/transit/engine.rs new file mode 100644 index 00000000..9fecbaf4 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/engine.rs @@ -0,0 +1,439 @@ +use std::{ + cmp, + collections::HashMap, + fs::File, + io::BufReader, + path::{Path, PathBuf}, + sync::Arc, +}; + +use crate::model::traversal::transit::{ + config::TransitTraversalConfig, + metadata::{self, GtfsArchiveMetadata}, + schedule::{Departure, Schedule}, + schedule_loading_policy::{self, ScheduleLoadingPolicy}, + transit_ops, +}; +use bambam_core::model::state::{MultimodalMapping, MultimodalStateMapping}; +use chrono::{NaiveDate, NaiveDateTime}; +use flate2::bufread::GzDecoder; +use routee_compass_core::{model::traversal::TraversalModelError, util::fs::read_utils}; +use serde::{Deserialize, Serialize}; +use skiplist::OrderedSkipList; +use uom::si::f64::Time; + +pub struct TransitTraversalEngine { + pub edge_schedules: Box<[HashMap]>, + pub date_mapping: HashMap>, +} + +impl TransitTraversalEngine { + pub fn get_next_departure( + &self, + edge_id: usize, + current_datetime: &NaiveDateTime, + ) -> Result<(i64, Departure), TraversalModelError> { + let departures_skiplists = + self.edge_schedules + .get(edge_id) + .ok_or(TraversalModelError::InternalError(format!( + "EdgeId {edge_id} exceeds schedules length" + )))?; + + // Iterate over all routes that have schedules on this edge + let result = departures_skiplists + .iter() + .map(|(route_id_label, skiplist)| { + // reconcile with any date mappings. used to address date gaps across all GTFS archives. + let search_datetime = transit_ops::apply_date_mapping( + &self.date_mapping, + route_id_label, + current_datetime, + ); + + // Query the skiplist + // We need to create the struct shell to be able to search the + // skiplist. I tried several other approaches but I think this is the cleanest + let search_query = Departure::construct_query(search_datetime); + + // get next or infinity. if infinity cannot be created: error + let next_route_departure = skiplist + .lower_bound(std::ops::Bound::Included(&search_query)) + .cloned() + .unwrap_or(Departure::infinity()); + + // Undo datemapping + let next_route_departure = transit_ops::reverse_date_mapping( + current_datetime, + &search_datetime, + next_route_departure, + ); + + // Return next departure for route + (*route_id_label, next_route_departure) + }) + .min_by_key(|(_, departure)| departure.dst_arrival_time) + .ok_or(TraversalModelError::InternalError(format!( + "failed to find next departure: schedules for edge_id {edge_id} appear to be empty" + )))?; + Ok(result) + } +} + +impl TryFrom for TransitTraversalEngine { + type Error = TraversalModelError; + + fn try_from(value: TransitTraversalConfig) -> Result { + log::debug!( + "loading transit traversal model from {}", + value.gtfs_metadata_input_file + ); + + // Deserialize metadata file + let file = File::open(value.gtfs_metadata_input_file).map_err(|e| { + TraversalModelError::BuildError(format!("Failed to read metadata file: {e}")) + })?; + let metadata: GtfsArchiveMetadata = + serde_json::from_reader(BufReader::new(file)).map_err(|e| { + TraversalModelError::BuildError(format!("Failed to read metadata file: {e}")) + })?; + + let route_id_to_state = match &value.route_ids_input_file { + Some(route_ids_input_file) => MultimodalStateMapping::from_enumerated_category_file( + Path::new(&route_ids_input_file), + )?, + None => MultimodalStateMapping::new(&metadata.fq_route_ids)?, + }; + + log::debug!( + "loaded {} fq route ids into mapping", + route_id_to_state.n_categories() + ); + + // re-map hash map keys from categorical to i64 label. + let date_mapping = build_label_to_date_mapping(&metadata, &route_id_to_state)?; + log::debug!("loaded date mapping with {} entries", date_mapping.len()); + + let edge_schedules = read_schedules_from_file( + value.edges_schedules_input_file, + Arc::new(route_id_to_state), + value.schedule_loading_policy, + )?; + + Ok(Self { + edge_schedules, + date_mapping, + }) + } +} + +/// This function assumes that edge_id's are dense. If any edge_id is skipped, the transformation from +/// a HashMap into Vec will fail +fn read_schedules_from_file( + filename: String, + route_mapping: Arc, + schedule_loading_policy: ScheduleLoadingPolicy, +) -> Result]>, TraversalModelError> { + // Reading csv + let rows: Box<[super::RawScheduleRow]> = + read_utils::from_csv(&Path::new(&filename), true, None, None).map_err(|e| { + TraversalModelError::BuildError(format!("Error creating reader to schedules file: {e}")) + })?; + + log::debug!("{filename} - loaded {} raw schedule rows", rows.len()); + + // Deserialize rows according to their edge_id + let mut schedules: HashMap> = HashMap::new(); + for record in rows { + let route_i64 = route_mapping.get_label(&record.fully_qualified_id).ok_or( + TraversalModelError::BuildError(format!( + "Cannot find route id mapping for string {}", + record.fully_qualified_id.clone() + )), + )?; + + // This step creates an empty skiplist for every edge we see, even if we don't load any departures to it + let schedule_skiplist = schedules + .entry(record.edge_id) + .or_default() + .entry(*route_i64) + .or_default(); + schedule_loading_policy.insert_if_valid( + schedule_skiplist, + Departure { + src_departure_time: record.src_departure_time, + dst_arrival_time: record.dst_arrival_time, + }, + ); + } + log::debug!( + "{filename} - built schedule lookup for {} routes", + schedules.len() + ); + + // Observe total number of keys (edge_ids) + let n_edges = schedules.keys().len(); + + // Re-arrange all into a dense boxed slice + let out = (0..n_edges) + .map(|i| { + schedules + .remove(&i) // TIL: `remove` returns an owned value, consuming the hashmap + .ok_or(TraversalModelError::BuildError(format!( + "Invalid schedules file. Missing edge_id {i} when the maximum edge_id is {n_edges}" + ))) + }) + .collect::>, TraversalModelError>>()?; + + log::debug!("{filename} - built skip lists for {} routes", out.len()); + + Ok(out.into_boxed_slice()) +} + +/// helper function to construct a mapping from categorical label (a i64 StateVariable) +/// into a date mapping. +fn build_label_to_date_mapping( + metadata: &GtfsArchiveMetadata, + route_id_to_state: &MultimodalStateMapping, +) -> Result>, TraversalModelError> { + let mapped = metadata + .fq_route_ids + .iter() + .map(|route_id| { + let label = route_id_to_state.get_label(route_id) + .ok_or_else(|| { + // this is only possible if the fq_route_ids are not the same as the dataset + // that created the state mapping. + TraversalModelError::BuildError( + "fully-qualified route id '{route_id}' has no entry in enumeration table from file".to_string() + ) + })?; + let mapping = match metadata.date_mapping.get(route_id) { + None => return Ok(None), + Some(mapping) => mapping, + }; + Ok(Some((*label, mapping.clone()))) + }) + .collect::, TraversalModelError>>()?; + let result = mapped.into_iter().flatten().collect(); + Ok(result) +} + +#[cfg(test)] +mod test { + + use crate::model::traversal::transit::{ + engine::TransitTraversalEngine, + schedule::{Departure, Schedule}, + }; + use chrono::{Months, NaiveDate, NaiveDateTime}; + use std::collections::HashMap; + use std::str::FromStr; + + fn internal_date(string: &str) -> NaiveDateTime { + NaiveDateTime::parse_from_str(&format!("20250101 {string}"), "%Y%m%d %H:%M:%S").unwrap() + } + + fn get_dummy_engine( + date_mapping: Option>>, + ) -> TransitTraversalEngine { + // There are two edges that reverse each other and two routes that move across them + // Route 1: + // 16:00 - 16:05 (A-B) -> 16:05 - 16:10 (B-A) -> 16:10 - 16:25 dwell -> 16:25 - 16:30 (A-B) -> 16:30 - 16:35 (B-A) + // + // Route 2: + // 16:15 - 16:45 (A-B) -> 16:45 - 17:00 (B-A) + let schedules: Vec> = vec![ + HashMap::from([ + ( + 0, + dummy_schedule(&[("16:00:00", "16:05:00"), ("16:25:00", "16:30:00")]), + ), + (1, dummy_schedule(&[("16:15:00", "16:45:00")])), + ]), + HashMap::from([ + ( + 0, + dummy_schedule(&[("16:05:00", "16:10:00"), ("16:30:00", "16:35:00")]), + ), + (1, dummy_schedule(&[("16:45:00", "17:00:00")])), + ]), + ]; + + TransitTraversalEngine { + edge_schedules: schedules.into_boxed_slice(), + date_mapping: date_mapping.unwrap_or_default(), + } + } + + fn dummy_schedule(times: &[(&str, &str)]) -> Schedule { + let departures = times.iter().map(|(src, dst)| Departure { + src_departure_time: internal_date(src), + dst_arrival_time: internal_date(dst), + }); + Schedule::from_iter(departures) + } + + #[test] + fn test_get_next_departure() { + let engine = get_dummy_engine(None); + + let mut current_edge: usize = 0; + let mut current_time = internal_date("15:50:00"); + let mut next_tuple = engine + .get_next_departure(current_edge, ¤t_time) + .unwrap(); + let mut next_route = next_tuple.0; + let mut next_departure = next_tuple.1; + + assert_eq!(next_route, 0); + assert_eq!(next_departure.src_departure_time, internal_date("16:00:00")); + + // Traverse 3 times the next edge + for i in 0..3 { + next_tuple = engine + .get_next_departure(current_edge, ¤t_time) + .unwrap(); + next_route = next_tuple.0; + next_departure = next_tuple.1; + + current_time = next_departure.dst_arrival_time; + current_edge = 1 - current_edge; + } + + assert_eq!(next_route, 0); + assert_eq!(current_time, internal_date("16:30:00")); + + // Ride transit one more time + next_tuple = engine + .get_next_departure(current_edge, ¤t_time) + .unwrap(); + next_route = next_tuple.0; + next_departure = next_tuple.1; + + current_time = next_departure.dst_arrival_time; + current_edge = 1 - current_edge; + + // If we wait now, we will find there are no more departures + next_tuple = engine + .get_next_departure(current_edge, ¤t_time) + .unwrap(); + next_route = next_tuple.0; + next_departure = next_tuple.1; + assert_eq!(next_departure, Departure::infinity()); + } + + #[test] + fn test_schedule_from_iter() { + let departures = vec![ + Departure { + src_departure_time: internal_date("10:00:00"), + dst_arrival_time: internal_date("10:15:00"), + }, + Departure { + src_departure_time: internal_date("08:00:00"), + dst_arrival_time: internal_date("08:20:00"), + }, + Departure { + src_departure_time: internal_date("09:00:00"), + dst_arrival_time: internal_date("09:10:00"), + }, + ]; + + let schedule = Schedule::from_iter(departures); + assert_eq!(schedule.len(), 3); + + // Should be ordered automatically + let ordered: Vec<&Departure> = schedule.iter().collect(); + assert_eq!(ordered[0].src_departure_time, internal_date("08:00:00")); + assert_eq!(ordered[1].src_departure_time, internal_date("09:00:00")); + assert_eq!(ordered[2].src_departure_time, internal_date("10:00:00")); + } + + #[test] + fn test_schedule_comprehensive_search_scenario() { + // Create a realistic bus schedule with multiple departures throughout the day + let schedule = dummy_schedule(&[ + ("06:00:00", "06:25:00"), // Early morning + ("06:30:00", "06:55:00"), + ("07:00:00", "07:25:00"), // Rush hour starts + ("07:15:00", "07:40:00"), + ("07:30:00", "07:55:00"), + ("08:00:00", "08:25:00"), + ("09:00:00", "09:25:00"), // Off-peak + ("10:00:00", "10:25:00"), + ("17:00:00", "17:25:00"), // Evening rush + ("17:30:00", "17:55:00"), + ("18:00:00", "18:25:00"), + ("22:00:00", "22:25:00"), // Late evening + ]); + + // Test various search scenarios + let test_cases = vec![ + ("05:30:00", Some("06:00:00")), // Before service starts + ("06:00:00", Some("06:00:00")), // Exact match + ("06:10:00", Some("06:30:00")), // Between departures + ("07:20:00", Some("07:30:00")), // During rush hour + ("12:00:00", Some("17:00:00")), // Large gap in service + ("21:00:00", Some("22:00:00")), // Evening service + ("23:00:00", None), // After service ends + ]; + + for (search_time, expected_time) in test_cases { + let search_departure = Departure { + src_departure_time: internal_date(search_time), + dst_arrival_time: internal_date(search_time), + }; + + let result = schedule.lower_bound(std::ops::Bound::Included(&search_departure)); + + match expected_time { + Some(expected) => { + assert!( + result.is_some(), + "Expected departure at {expected} for search time {search_time}" + ); + assert_eq!( + result.unwrap().src_departure_time, + internal_date(expected), + "Search time {search_time} should find departure at {expected}" + ); + } + None => { + assert!( + result.is_none(), + "Expected no departure for search time {search_time}" + ); + } + } + } + } + + #[test] + fn test_positive_travel_time_after_datemapping() { + // Instantiating a datemapping that maps to a day before + let ref_date = NaiveDate::parse_from_str("20250101", "%Y%m%d").unwrap(); + let current_date = NaiveDate::parse_from_str("20250102", "%Y%m%d").unwrap(); + let single_date_mapping: HashMap = + [(current_date, ref_date)].into_iter().collect(); + + let date_mapping: Option>> = Some( + [ + (0, single_date_mapping.clone()), + (1, single_date_mapping.clone()), + ] + .into_iter() + .collect(), + ); + let engine = get_dummy_engine(date_mapping); + + let mut current_edge: usize = 0; + let mut current_time = + NaiveDateTime::parse_from_str("20250102 15:55:00", "%Y%m%d %H:%M:%S").unwrap(); + let mut next_tuple = engine + .get_next_departure(current_edge, ¤t_time) + .unwrap(); + + assert!((next_tuple.1.src_departure_time - current_time).as_seconds_f64() >= 0.); + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/metadata.rs b/rust/bambam-gtfs/src/model/traversal/transit/metadata.rs new file mode 100644 index 00000000..f5119495 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/metadata.rs @@ -0,0 +1,73 @@ +use bambam_core::util::date_deserialization_ops::naive_date_to_str; +use chrono::{Duration, NaiveDate}; +use routee_compass_core::model::traversal::TraversalModelError; +use serde::de::Error; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// fields from the metadata JSON file that are relevant for loading transit traversal models. +/// additional fields exist +#[derive(Serialize, Deserialize)] +pub struct GtfsArchiveMetadata { + /// direct output of GTFS agencies.txt + pub agencies: Vec, + /// direct output of GTFS feed_info.txt + pub feed_info: Vec, + /// time required to read this archive using bambam-gtfs + #[serde(deserialize_with = "deserialize_duration")] + pub read_duration: Duration, + /// direct output of GTFS calendar.txt by service_id + pub calendar: Value, + /// direct output of GTFS calendar_dates.txt by service_id + pub calendar_dates: Value, + /// Mapping from target date to available date for each route_id + #[serde(deserialize_with = "deserialize_date_mapping")] + pub date_mapping: HashMap>, + /// List of unique (fully-qualified) route identifiers used in the schedules + pub fq_route_ids: Vec, +} + +#[derive(Deserialize)] +struct DurationJson { + pub secs: i64, + pub nanos: u32, +} + +fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let DurationJson { secs, nanos } = DurationJson::deserialize(deserializer)?; + chrono::Duration::new(secs, nanos).ok_or_else(|| { + D::Error::custom(format!( + "invalid duration value with secs {secs}, nanos {nanos}" + )) + }) +} + +fn deserialize_date_mapping<'de, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let original_map = HashMap::>::deserialize(deserializer)?; + + // Convert inner maps to NaiveDate keys/values + let mut out: HashMap> = + HashMap::with_capacity(original_map.len()); + for (route_id, inner) in original_map { + let mut parsed_inner = HashMap::with_capacity(inner.len()); + for (k_str, v_str) in inner { + let k = naive_date_to_str(&k_str) + .map_err(|e| D::Error::custom(format!("failed to deserialize date mapping for route_id `{route_id}`: invalid date key `{k_str}`: {e}")))?; + let v = naive_date_to_str(&v_str) + .map_err(|e| D::Error::custom(format!("failed to deserialize date mapping for route_id `{route_id}`: invalid date value `{v_str}`: {e}")))?; + parsed_inner.insert(k, v); + } + out.insert(route_id, parsed_inner); + } + + Ok(out) +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/mod.rs b/rust/bambam-gtfs/src/model/traversal/transit/mod.rs new file mode 100644 index 00000000..bcf0843a --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/mod.rs @@ -0,0 +1,22 @@ +mod builder; +mod config; +mod engine; +mod metadata; +mod model; +mod query; +mod raw_schedule_row; +mod schedule; +mod schedule_loading_policy; +mod service; + +pub mod transit_ops; +pub use builder::TransitTraversalBuilder; +pub use config::TransitTraversalConfig; +pub use engine::TransitTraversalEngine; +pub use metadata::GtfsArchiveMetadata; +pub use model::TransitTraversalModel; +pub use query::TransitTraversalQuery; +pub use raw_schedule_row::RawScheduleRow; +pub use schedule::{Departure, Schedule}; +pub use schedule_loading_policy::ScheduleLoadingPolicy; +pub use service::TransitTraversalService; diff --git a/rust/bambam-gtfs/src/model/traversal/transit/model.rs b/rust/bambam-gtfs/src/model/traversal/transit/model.rs new file mode 100644 index 00000000..04badd32 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/model.rs @@ -0,0 +1,165 @@ +use std::sync::Arc; + +use crate::model::traversal::transit::transit_ops; +use crate::model::traversal::transit::{engine::TransitTraversalEngine, schedule::Departure}; +use bambam_core::model::bambam_state::{self, ROUTE_ID}; +use bambam_core::model::state::variable::EMPTY; +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use routee_compass_core::model::traversal::TraversalModelError; +use routee_compass_core::model::{ + state::StateVariableConfig, + traversal::{default::fieldname, TraversalModel}, +}; +use uom::{ + si::f64::{Length, Time}, + ConstZero, +}; + +pub struct TransitTraversalModel { + engine: Arc, + start_datetime: NaiveDateTime, + record_dwell_time: bool, +} + +impl TransitTraversalModel { + pub fn new( + engine: Arc, + start_datetime: NaiveDateTime, + record_dwell_time: bool, + ) -> Self { + Self { + engine, + start_datetime, + record_dwell_time, + } + } +} + +impl TraversalModel for TransitTraversalModel { + fn name(&self) -> String { + "transit_traversal".to_string() + } + + fn input_features(&self) -> Vec { + vec![] + } + + fn output_features( + &self, + ) -> Vec<( + String, + routee_compass_core::model::state::StateVariableConfig, + )> { + let mut out = vec![ + ( + String::from(fieldname::TRIP_TIME), + StateVariableConfig::Time { + initial: Time::ZERO, + output_unit: None, + accumulator: true, + }, + ), + ( + String::from(fieldname::EDGE_TIME), + StateVariableConfig::Time { + initial: Time::ZERO, + output_unit: None, + accumulator: false, + }, + ), + ( + String::from(bambam_state::ROUTE_ID), + StateVariableConfig::Custom { + custom_type: "RouteId".to_string(), + value: EMPTY, + accumulator: true, + }, + ), + ( + String::from(bambam_state::TRANSIT_BOARDING_TIME), + StateVariableConfig::Time { + initial: Time::ZERO, + accumulator: false, + output_unit: None, + }, + ), + ]; + + if self.record_dwell_time { + out.push(( + String::from(bambam_state::DWELL_TIME), + StateVariableConfig::Time { + initial: Time::ZERO, + accumulator: false, + output_unit: None, + }, + )); + } + + out + } + + fn traverse_edge( + &self, + trajectory: ( + &routee_compass_core::model::network::Vertex, + &routee_compass_core::model::network::Edge, + &routee_compass_core::model::network::Vertex, + ), + state: &mut Vec, + tree: &routee_compass_core::algorithm::search::SearchTree, + state_model: &routee_compass_core::model::state::StateModel, + ) -> Result<(), routee_compass_core::model::traversal::TraversalModelError> { + let current_edge_id = trajectory.1.edge_id; + let current_route_id = state_model.get_custom_i64(state, bambam_state::ROUTE_ID)?; + let current_datetime = + transit_ops::get_current_time(&self.start_datetime, state, state_model)?; + + // get the next departure. + // in the case that no schedules are found, a sentinel value is returned set + // far in the future (an "infinity" value). this indicates that this edge should not + // have been accepted by the FrontierModel. but at this point, we do not have a + // transit frontier model, so "infinity" must solve the same problem. + let (next_route, next_departure) = self + .engine + .get_next_departure(current_edge_id.as_usize(), ¤t_datetime)?; + let next_departure_route_id = next_route; + + // update the state. a bunch of features are modified here. + // NOTE: wait_time is "time waiting in the transit stop" OR "time waiting sitting on the bus during scheduled dwell time" + let wait_time = Time::new::( + (next_departure.src_departure_time - current_datetime).as_seconds_f64(), + ); + let travel_time = Time::new::( + (next_departure.dst_arrival_time - next_departure.src_departure_time).as_seconds_f64(), + ); + let total_time = wait_time + travel_time; + + // Update state + state_model.add_time(state, fieldname::TRIP_TIME, &total_time); + state_model.add_time(state, fieldname::EDGE_TIME, &total_time); + state_model.set_custom_i64(state, ROUTE_ID, &next_departure_route_id); + + // TRANSIT_BOARDING_TIME accumulates time waiting at transit stops, but not dwell time + if current_route_id != next_departure_route_id { + state_model.add_time(state, bambam_state::TRANSIT_BOARDING_TIME, &wait_time); + } else if self.record_dwell_time { + state_model.add_time(state, bambam_state::DWELL_TIME, &wait_time); + } + + Ok(()) + } + + fn estimate_traversal( + &self, + od: ( + &routee_compass_core::model::network::Vertex, + &routee_compass_core::model::network::Vertex, + ), + state: &mut Vec, + tree: &routee_compass_core::algorithm::search::SearchTree, + state_model: &routee_compass_core::model::state::StateModel, + ) -> Result<(), routee_compass_core::model::traversal::TraversalModelError> { + Ok(()) + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/query.rs b/rust/bambam-gtfs/src/model/traversal/transit/query.rs new file mode 100644 index 00000000..62e2a5ae --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/query.rs @@ -0,0 +1,12 @@ +use bambam_core::util::date_deserialization_ops::deserialize_naive_datetime; +use chrono::NaiveDateTime; +use serde::{Deserialize, Serialize}; +use uom::si::f64::Time; + +#[derive(Serialize, Deserialize)] +pub struct TransitTraversalQuery { + #[serde(deserialize_with = "deserialize_naive_datetime")] + pub start_datetime: NaiveDateTime, // Fix deserialization + /// If true, we maintain a DWELL_TIME state variable + pub record_dwell_time: Option, +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/raw_schedule_row.rs b/rust/bambam-gtfs/src/model/traversal/transit/raw_schedule_row.rs new file mode 100644 index 00000000..fb61d1ac --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/raw_schedule_row.rs @@ -0,0 +1,13 @@ +use chrono::NaiveDateTime; +use serde::{Deserialize, Serialize}; + +/// record type storing a single scheduled departure and arrival +/// within a route. +#[derive(Debug, Deserialize, Serialize)] +pub struct RawScheduleRow { + pub edge_id: usize, + /// fully-qualified route id + pub fully_qualified_id: String, + pub src_departure_time: NaiveDateTime, + pub dst_arrival_time: NaiveDateTime, +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/schedule.rs b/rust/bambam-gtfs/src/model/traversal/transit/schedule.rs new file mode 100644 index 00000000..14700c40 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/schedule.rs @@ -0,0 +1,271 @@ +use std::ops::Add; + +use chrono::{Duration, Months, NaiveDateTime, TimeDelta}; +use skiplist::OrderedSkipList; + +/// a schedule contains an ordered list of [`Departure`] values. +pub type Schedule = OrderedSkipList; + +/// a single departure from a src location, recorded as its pair of +/// departure time from here and arrival time at some dst location. +#[derive(Debug, Clone, Eq, Copy)] +pub struct Departure { + pub src_departure_time: NaiveDateTime, + pub dst_arrival_time: NaiveDateTime, +} + +impl Departure { + pub fn construct_query(datetime: NaiveDateTime) -> Self { + Self { + src_departure_time: datetime, + dst_arrival_time: datetime, + } + } + + /// represent infinity in the time space of departures + pub fn infinity() -> Self { + Departure { + src_departure_time: NaiveDateTime::MAX, + dst_arrival_time: NaiveDateTime::MAX, + } + } + + /// the departure is placed at positive infinity. occurs + /// when adding extreme TimeDelta values. + pub fn is_pos_infinity(&self) -> bool { + self.src_departure_time == NaiveDateTime::MAX || self.dst_arrival_time == NaiveDateTime::MAX + } + + /// the departure is placed at negative infinity. occurs + /// when adding extreme TimeDelta values. + pub fn is_neg_infinity(&self) -> bool { + self.src_departure_time == NaiveDateTime::MIN || self.dst_arrival_time == NaiveDateTime::MIN + } +} + +impl Add<&TimeDelta> for Departure { + type Output = Departure; + /// adds to a Departure. clamps at absolute MIN or MAX time values. + fn add(self, rhs: &TimeDelta) -> Self::Output { + let src_departure_time = add_time_to_datetime(&self.src_departure_time, rhs); + let dst_arrival_time = add_time_to_datetime(&self.dst_arrival_time, rhs); + Departure { + src_departure_time, + dst_arrival_time, + } + } +} + +impl PartialEq for Departure { + fn eq(&self, other: &Self) -> bool { + self.src_departure_time == other.src_departure_time + && self.dst_arrival_time == other.dst_arrival_time + } +} + +impl PartialOrd for Departure { + fn partial_cmp(&self, other: &Self) -> Option { + self.src_departure_time + .partial_cmp(&other.src_departure_time) + } +} + +impl Ord for Departure { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.src_departure_time.cmp(&other.src_departure_time) + } +} + +/// Adds a time delta to a datetime, clamping to MIN/MAX on overflow. +/// +/// # Arguments +/// * `date_time` - The base datetime +/// * `time_delta` - The duration to add (can be negative) +/// +/// # Returns +/// - The sum if it fits within NaiveDateTime's range +/// - NaiveDateTime::MIN if negative overflow occurs +/// - NaiveDateTime::MAX if positive overflow occurs +fn add_time_to_datetime(date_time: &NaiveDateTime, time_delta: &TimeDelta) -> NaiveDateTime { + date_time + .checked_add_signed(*time_delta) + .unwrap_or_else(|| { + if time_delta < &TimeDelta::zero() { + NaiveDateTime::MIN + } else { + NaiveDateTime::MAX + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::NaiveDateTime; + + #[test] + fn test_departure_add_normal() { + let departure = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "2023-06-15 10:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "2023-06-15 11:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + let delta = TimeDelta::hours(2); + let result = departure + δ + + assert_eq!( + result.src_departure_time, + NaiveDateTime::parse_from_str("2023-06-15 12:00:00", "%Y-%m-%d %H:%M:%S").unwrap() + ); + assert_eq!( + result.dst_arrival_time, + NaiveDateTime::parse_from_str("2023-06-15 13:00:00", "%Y-%m-%d %H:%M:%S").unwrap() + ); + } + + #[test] + fn test_departure_add_negative() { + let departure = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "2023-06-15 10:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "2023-06-15 11:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + let delta = TimeDelta::hours(-2); + let result = departure + δ + + assert_eq!( + result.src_departure_time, + NaiveDateTime::parse_from_str("2023-06-15 08:00:00", "%Y-%m-%d %H:%M:%S").unwrap() + ); + assert_eq!( + result.dst_arrival_time, + NaiveDateTime::parse_from_str("2023-06-15 09:00:00", "%Y-%m-%d %H:%M:%S").unwrap() + ); + } + + #[test] + fn test_departure_add_overflow_to_max() { + let departure = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:30:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + // Adding a huge duration that will overflow + let delta = TimeDelta::days(365 * 1000000); // 1M years + let result = departure + δ + + assert_eq!( + result.src_departure_time, + NaiveDateTime::MAX, + "Should clamp to MAX on positive overflow" + ); + assert_eq!( + result.dst_arrival_time, + NaiveDateTime::MAX, + "Should clamp to MAX on positive overflow" + ); + } + + #[test] + fn test_departure_add_underflow_to_min() { + let departure = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "0001-01-01 01:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "0001-01-01 01:30:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + // Subtracting a huge duration that will underflow + let delta = TimeDelta::days(-365 * 1000000); // -1M years + let result = departure + δ + + assert_eq!( + result.src_departure_time, + NaiveDateTime::MIN, + "Should clamp to MIN on negative overflow" + ); + assert_eq!( + result.dst_arrival_time, + NaiveDateTime::MIN, + "Should clamp to MIN on negative overflow" + ); + } + + #[test] + fn test_departure_infinity() { + let inf = Departure::infinity(); + assert!(inf.is_pos_infinity()); + assert_eq!(inf.src_departure_time, NaiveDateTime::MAX); + assert_eq!(inf.dst_arrival_time, NaiveDateTime::MAX); + } + + #[test] + fn test_departure_add_to_infinity_stays_infinity() { + let inf = Departure::infinity(); + let delta = TimeDelta::hours(5); + let result = inf + δ + + // Adding to MAX should stay at MAX + assert_eq!(result.src_departure_time, NaiveDateTime::MAX); + assert_eq!(result.dst_arrival_time, NaiveDateTime::MAX); + assert!(result.is_pos_infinity()); + } + + #[test] + fn test_departure_ordering() { + let early = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "2023-06-15 10:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "2023-06-15 11:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + let late = Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "2023-06-15 12:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "2023-06-15 13:00:00", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + + assert!(early < late); + assert!(late > early); + assert_eq!(early, early); + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/schedule_loading_policy.rs b/rust/bambam-gtfs/src/model/traversal/transit/schedule_loading_policy.rs new file mode 100644 index 00000000..a13d8ee9 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/schedule_loading_policy.rs @@ -0,0 +1,34 @@ +use chrono::NaiveDateTime; +use routee_compass::plugin::input::default; +use serde::{Deserialize, Serialize}; + +use crate::model::traversal::transit::schedule::{Departure, Schedule}; + +#[derive(Debug, Default, Serialize, Deserialize)] +pub enum ScheduleLoadingPolicy { + #[default] + All, + InDateRange { + start_date: NaiveDateTime, + end_date: NaiveDateTime, + }, +} + +impl ScheduleLoadingPolicy { + pub fn insert_if_valid(&self, schedule_skiplist: &mut Schedule, element: Departure) { + let should_insert = match self { + ScheduleLoadingPolicy::All => true, + ScheduleLoadingPolicy::InDateRange { + start_date, + end_date, + } => { + (element.src_departure_time <= *end_date) + && (*start_date <= element.src_departure_time) + } + }; + + if should_insert { + schedule_skiplist.insert(element); + } + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/service.rs b/rust/bambam-gtfs/src/model/traversal/transit/service.rs new file mode 100644 index 00000000..45776340 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/service.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +use routee_compass_core::model::traversal::{ + TraversalModel, TraversalModelError, TraversalModelService, +}; + +use crate::model::traversal::transit::{ + engine::TransitTraversalEngine, model::TransitTraversalModel, query::TransitTraversalQuery, +}; + +pub struct TransitTraversalService { + engine: Arc, +} + +impl TransitTraversalService { + pub fn new(engine: Arc) -> Self { + Self { engine } + } +} + +impl TraversalModelService for TransitTraversalService { + fn build( + &self, + query: &serde_json::Value, + ) -> Result, TraversalModelError> { + let model_query: TransitTraversalQuery = + serde_json::from_value(query.clone()).map_err(|e| { + TraversalModelError::BuildError(format!( + "failed to deserialize configuration for transit traversal model: {e}" + )) + })?; + + let model = TransitTraversalModel::new( + self.engine.clone(), + model_query.start_datetime, + model_query.record_dwell_time.unwrap_or_default(), + ); + Ok(Arc::new(model)) + } +} diff --git a/rust/bambam-gtfs/src/model/traversal/transit/transit_ops.rs b/rust/bambam-gtfs/src/model/traversal/transit/transit_ops.rs new file mode 100644 index 00000000..0f51ff79 --- /dev/null +++ b/rust/bambam-gtfs/src/model/traversal/transit/transit_ops.rs @@ -0,0 +1,474 @@ +use std::{collections::HashMap, ops::Add}; + +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use routee_compass_core::model::{ + state::{StateModel, StateVariable}, + traversal::TraversalModelError, +}; + +use crate::model::traversal::transit::Departure; +use bambam_core::model::state::fieldname; + +/// composes the start time and the current trip_time into a new datetime value. +pub fn get_current_time( + start_datetime: &NaiveDateTime, + state: &[StateVariable], + state_model: &StateModel, +) -> Result { + let trip_time = state_model + .get_time(state, fieldname::TRIP_TIME)? + .get::(); + let seconds = trip_time as i64; + let remainder = (trip_time - seconds as f64); + let nanos = (remainder * 1_000_000_000.0) as u32; + let trip_duration = Duration::new(seconds, nanos).ok_or_else(|| { + TraversalModelError::TraversalModelFailure(format!( + "unable to build Duration from seconds, nanos: {seconds}, {nanos}" + )) + })?; + + let current_datetime = start_datetime.checked_add_signed(trip_duration).ok_or( + TraversalModelError::InternalError(format!( + "Invalid Datetime from Date {start_datetime} + {trip_time} seconds" + )), + )?; + Ok(current_datetime) +} + +/// checks for any date mapping for the current date/time value and applies it if found. +pub fn apply_date_mapping( + date_mapping: &HashMap>, + route_id_label: &i64, + current_datetime: &NaiveDateTime, +) -> NaiveDateTime { + date_mapping + .get(route_id_label) + .and_then(|date_map| date_map.get(¤t_datetime.date())) + .unwrap_or(¤t_datetime.date()) + .and_time(current_datetime.time()) +} + +/// finds the difference in days between the current and the mapped date and uses that +/// difference to modify the departure time to make it relevant for this search. +pub fn reverse_date_mapping( + current_datetime: &NaiveDateTime, + mapped_datetime: &NaiveDateTime, + departure: Departure, +) -> Departure { + if departure.is_pos_infinity() { + return departure; + } + let diff = current_datetime.signed_duration_since(*mapped_datetime); + + departure + &diff +} + +#[cfg(test)] +mod tests { + use chrono::{Datelike, Duration, NaiveDateTime}; + use routee_compass_core::model::{ + state::{StateModel, StateVariable, StateVariableConfig}, + unit::TimeUnit, + }; + use uom::si::f64::Time; + + use bambam_core::model::state::fieldname; + + fn mock_state(time: Time, state_model: &StateModel) -> Vec { + let mut state = state_model + .initial_state(None) + .expect("test invariant failed: could not create initial state"); + state_model + .set_time(&mut state, fieldname::TRIP_TIME, &time) + .unwrap_or_else(|_| { + panic!( + "test invariant failed: could not set time value of {} for state", + time.value + ) + }); + state + } + + fn mock_state_model(time_unit: Option) -> StateModel { + let trip_time_config = StateVariableConfig::Time { + initial: Time::new::(0.0), + accumulator: true, + output_unit: time_unit, + }; + StateModel::new(vec![(fieldname::TRIP_TIME.to_string(), trip_time_config)]) + } + + #[test] + fn test_get_current_time_various_scenarios() { + use uom::si::time::second; + + let test_cases = vec![ + // (name, start_time, trip_seconds, expected_time, description) + ( + "basic_composition", + "2023-06-15 08:30:00", + 3600.0, + "2023-06-15 09:30:00", + "1 hour trip", + ), + ( + "fractional_seconds", + "2023-06-15 08:30:00", + 1800.5, + "2023-06-15 09:00:00.500000000", + "30min + 500ms", + ), + ( + "midnight_wrapping", + "2023-06-15 23:30:00", + 3600.0, + "2023-06-16 00:30:00", + "wrap to next day", + ), + ( + "zero_trip_time", + "2023-06-15 14:45:30", + 0.0, + "2023-06-15 14:45:30", + "no time elapsed", + ), + ( + "multi_day_journey", + "2023-06-15 12:00:00", + 259200.0, + "2023-06-18 12:00:00", + "3 days", + ), + ( + "year_boundary", + "2023-12-31 23:59:59", + 1.0, + "2024-01-01 00:00:00", + "cross year", + ), + ( + "non_leap_month", + "2023-02-28 23:30:00", + 1800.0, + "2023-03-01 00:00:00", + "Feb to Mar", + ), + ( + "leap_year", + "2024-02-28 23:30:00", + 1800.0, + "2024-02-29 00:00:00", + "leap year Feb", + ), + ]; + + for (name, start_str, trip_seconds, expected_str, description) in test_cases { + let start_datetime = NaiveDateTime::parse_from_str(start_str, "%Y-%m-%d %H:%M:%S") + .or_else(|_| NaiveDateTime::parse_from_str(start_str, "%Y-%m-%d %H:%M:%S%.f")) + .unwrap_or_else(|_| panic!("Failed to parse start datetime for {name}")); + + let expected = NaiveDateTime::parse_from_str(expected_str, "%Y-%m-%d %H:%M:%S") + .or_else(|_| NaiveDateTime::parse_from_str(expected_str, "%Y-%m-%d %H:%M:%S%.f")) + .unwrap_or_else(|_| panic!("Failed to parse expected datetime for {name}")); + + let state_model = mock_state_model(None); + let trip_time = Time::new::(trip_seconds); + let state = mock_state(trip_time, &state_model); + + let result = super::get_current_time(&start_datetime, &state, &state_model) + .unwrap_or_else(|_| panic!("{name} ({description}) should succeed")); + + assert_eq!(result, expected, "{name}: {description}"); + } + } + + #[test] + fn test_get_current_time_different_time_units() { + use uom::si::time::{hour, minute}; + + // Test with different TimeUnit configurations + let start_datetime = + NaiveDateTime::parse_from_str("2023-06-15 10:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse test datetime"); + + // Test with minute units + let state_model_minutes = mock_state_model(Some(TimeUnit::Minutes)); + let trip_time_minutes = Time::new::(30.0); // 30 minutes + let state_minutes = mock_state(trip_time_minutes, &state_model_minutes); + + let result_minutes = + super::get_current_time(&start_datetime, &state_minutes, &state_model_minutes) + .expect("get_current_time should succeed with minutes"); + + // Test with hour units + let state_model_hours = mock_state_model(Some(TimeUnit::Hours)); + let trip_time_hours = Time::new::(0.5); // 0.5 hours = 30 minutes + let state_hours = mock_state(trip_time_hours, &state_model_hours); + + let result_hours = + super::get_current_time(&start_datetime, &state_hours, &state_model_hours) + .expect("get_current_time should succeed with hours"); + + let expected = NaiveDateTime::parse_from_str("2023-06-15 10:30:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse expected datetime"); + + // Both should produce the same result + assert_eq!(result_minutes, expected); + assert_eq!(result_hours, expected); + assert_eq!(result_minutes, result_hours); + } + + #[test] + fn test_get_current_time_precise_fractional_composition() { + // Test precise fractional second handling + let start_datetime = + NaiveDateTime::parse_from_str("2023-06-15 15:20:10", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse test datetime"); + let state_model = mock_state_model(None); + let trip_time = Time::new::(125.123456789); // 2 minutes, 5.123456789 seconds + let state = mock_state(trip_time, &state_model); + + let result = super::get_current_time(&start_datetime, &state, &state_model) + .expect("get_current_time should succeed"); + + // Expected: 15:20:10 + 125.123456789s = 15:22:15.123456789 + // chrono handles nanosecond precision + let expected_seconds = 125i64; + let expected_nanos = 123_456_789u32; + let expected = + start_datetime + chrono::Duration::new(expected_seconds, expected_nanos).unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn test_get_current_time_error_cases() { + // Test error case: invalid duration construction + let start_datetime = + NaiveDateTime::parse_from_str("2023-06-15 12:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse test datetime"); + let state_model = mock_state_model(None); + + // Test with negative time (should be caught by Duration::new if invalid) + let trip_time = Time::new::(-1.0); + let state = mock_state(trip_time, &state_model); + + // This might succeed or fail depending on chrono's handling of negative durations + // The behavior should be consistent + let result = super::get_current_time(&start_datetime, &state, &state_model); + + // For negative values, we expect either success (if chrono handles it) or a specific error + match result { + Ok(_) => { + // If it succeeds, the result should be before the start time + assert!(result.unwrap() < start_datetime); + } + Err(e) => { + // Should be a specific error about duration or datetime construction + assert!(matches!(e, + routee_compass_core::model::traversal::TraversalModelError::TraversalModelFailure(_) | + routee_compass_core::model::traversal::TraversalModelError::InternalError(_) + )); + } + } + } + + #[test] + fn test_reverse_date_mapping_normal_cases() { + let test_cases = vec![ + // (name, current, search, departure_src, departure_dst, expected_src, expected_dst) + ( + "basic_positive_delay", + "2023-06-15 14:30:00", + "2023-06-15 10:00:00", + "2023-06-15 11:00:00", + "2023-06-15 11:30:00", + "2023-06-15 15:30:00", + "2023-06-15 16:00:00", + ), + ( + "search_after_current", + "2023-06-15 14:30:00", + "2023-06-20 10:00:00", + "2023-06-20 15:00:00", + "2023-06-20 15:30:00", + "2023-06-15 19:30:00", + "2023-06-15 20:00:00", + ), + ( + "negative_delay", + "2023-06-15 14:30:00", + "2023-06-15 12:00:00", + "2023-06-15 10:00:00", + "2023-06-15 10:30:00", + "2023-06-15 12:30:00", + "2023-06-15 13:00:00", + ), + ]; + + for (name, current_str, search_str, dep_src_str, dep_dst_str, exp_src_str, exp_dst_str) in + test_cases + { + let current_datetime = + NaiveDateTime::parse_from_str(current_str, "%Y-%m-%d %H:%M:%S").unwrap(); + let search_datetime = + NaiveDateTime::parse_from_str(search_str, "%Y-%m-%d %H:%M:%S").unwrap(); + let departure = super::Departure { + src_departure_time: NaiveDateTime::parse_from_str(dep_src_str, "%Y-%m-%d %H:%M:%S") + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str(dep_dst_str, "%Y-%m-%d %H:%M:%S") + .unwrap(), + }; + let expected_src = + NaiveDateTime::parse_from_str(exp_src_str, "%Y-%m-%d %H:%M:%S").unwrap(); + let expected_dst = + NaiveDateTime::parse_from_str(exp_dst_str, "%Y-%m-%d %H:%M:%S").unwrap(); + + let result = + super::reverse_date_mapping(¤t_datetime, &search_datetime, departure); + + assert_eq!( + result.src_departure_time, expected_src, + "{name}: src_departure_time" + ); + assert_eq!( + result.dst_arrival_time, expected_dst, + "{name}: dst_arrival_time" + ); + } + } + + #[test] + fn test_reverse_date_mapping_with_infinity_past_date() { + // Test that reverse_date_mapping correctly handles Departure::infinity() + let current_datetime = + NaiveDateTime::parse_from_str("2023-06-15 14:30:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse current datetime"); + let search_datetime = + NaiveDateTime::parse_from_str("2023-06-20 10:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse search datetime"); + + // Use Departure::infinity() which has NaiveDateTime::MAX for both times + let infinity_departure = super::Departure::infinity(); + + let result = + super::reverse_date_mapping(¤t_datetime, &search_datetime, infinity_departure); + + // With overflow protection, infinity should return infinity + assert_eq!(result, super::Departure::infinity()); + } + + #[test] + fn test_reverse_date_mapping_large_future_departure() { + // Test with a departure far in the future + let current_datetime = + NaiveDateTime::parse_from_str("2023-06-15 14:30:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse current datetime"); + let search_datetime = + NaiveDateTime::parse_from_str("2023-06-15 10:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse search datetime"); + + // Create a departure very far in the future (year 9999) + let departure = super::Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:59:59", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:59:59", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + + let result = super::reverse_date_mapping(¤t_datetime, &search_datetime, departure); + + // The delay is ~7976 years, adding to 2023 gives ~10000 + // This doesn't overflow, it produces a valid far-future date + assert!( + result.src_departure_time.year() >= 10000, + "Expected year 10000+, got {}", + result.src_departure_time.year() + ); + assert_eq!(result.src_departure_time, result.dst_arrival_time); + } + + #[test] + fn test_reverse_date_mapping_search_after_current_with_large_departure() { + // Test case 3 with extreme values that would cause overflow without protection + let current_datetime = + NaiveDateTime::parse_from_str("2023-06-15 14:30:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse current datetime"); + let search_datetime = + NaiveDateTime::parse_from_str("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse search datetime"); + + // Departure is very far in the future relative to search + let departure = super::Departure { + src_departure_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:59:59", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + dst_arrival_time: NaiveDateTime::parse_from_str( + "9999-12-31 23:59:59", + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + }; + + let result = super::reverse_date_mapping(¤t_datetime, &search_datetime, departure); + + // The delay from search to departure is ~8000 years + // Adding this to current_datetime (2023) gives ~10023 + // This doesn't overflow, but produces a very far future date + assert!( + result.src_departure_time.year() >= 10000, + "Expected year 10000+, got {}", + result.src_departure_time.year() + ); + assert_eq!(result.src_departure_time, result.dst_arrival_time); + } + + #[test] + fn test_reverse_date_mapping_overflow_to_max() { + // Test with values that will actually cause overflow and clamp to MAX + // Construct a far-future date by adding duration to a base date + let base_date = NaiveDateTime::parse_from_str("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S") + .expect("Failed to parse base datetime"); + + // Add 50,000 years worth of days + let current_datetime = base_date + .checked_add_signed(Duration::days(50000 * 365)) + .expect("Failed to create far-future current datetime"); + + let search_datetime = base_date; + + // Create a departure close to MAX to ensure overflow + // MAX is year +262142 + let near_max = NaiveDateTime::MAX + .checked_sub_signed(Duration::days(365)) + .unwrap(); + let departure = super::Departure { + src_departure_time: near_max, + dst_arrival_time: near_max, + }; + + let result = super::reverse_date_mapping(¤t_datetime, &search_datetime, departure); + + // The delay from base_date to near-MAX is ~260,000 years + // Adding it to current_datetime (50,000 years in future) will overflow + // This should clamp to MAX + assert_eq!( + result.src_departure_time, + NaiveDateTime::MAX, + "Should clamp to MAX on overflow" + ); + assert_eq!( + result.dst_arrival_time, + NaiveDateTime::MAX, + "Should clamp to MAX on overflow" + ); + } +} diff --git a/rust/bambam-gtfs/src/schedule/app/operation.rs b/rust/bambam-gtfs/src/schedule/app/operation.rs index aebb0287..1ba191f8 100644 --- a/rust/bambam-gtfs/src/schedule/app/operation.rs +++ b/rust/bambam-gtfs/src/schedule/app/operation.rs @@ -244,7 +244,7 @@ fn load_vertices_and_create_spatial_index( Some(bar_builder), None, ) - .map_err(|e| ScheduleError::FailedToCreateVertexIndexError(format!("{e}")))?; + .map_err(|e| ScheduleError::FailedToCreateVertexIndex(format!("{e}")))?; let tol: Length = uom::si::f64::Length::new::(tolerance_meters); Ok(Arc::new(SpatialIndex::new_vertex_oriented( &vertices, @@ -269,14 +269,12 @@ fn manifest_into_rows( .from_path(path_buf.as_path()) .map_err(|e| { let filename = path_buf.to_str().unwrap_or_default(); - ScheduleError::GtfsAppError(format!("failure reading '{filename}': {e}")) + ScheduleError::GtfsApp(format!("failure reading '{filename}': {e}")) })?; let rows = reader .into_deserialize::() .map(|r| { - r.map_err(|e| { - ScheduleError::GtfsAppError(format!("failure reading GTFS manifest row: {e}")) - }) + r.map_err(|e| ScheduleError::GtfsApp(format!("failure reading GTFS manifest row: {e}"))) }) .collect::, ScheduleError>>()?; let us_rows: Vec = rows diff --git a/rust/bambam-gtfs/src/schedule/bundle_ops.rs b/rust/bambam-gtfs/src/schedule/bundle_ops.rs index 962a5c53..d2e50825 100644 --- a/rust/bambam-gtfs/src/schedule/bundle_ops.rs +++ b/rust/bambam-gtfs/src/schedule/bundle_ops.rs @@ -72,9 +72,9 @@ pub fn batch_process( ) -> Result<(), ScheduleError> { let archive_paths = bundle_directory_path .read_dir() - .map_err(|e| ScheduleError::GtfsAppError(format!("failure reading directory: {e}")))? + .map_err(|e| ScheduleError::GtfsApp(format!("failure reading directory: {e}")))? .collect::, _>>() - .map_err(|e| ScheduleError::GtfsAppError(format!("failure reading directory: {e}")))?; + .map_err(|e| ScheduleError::GtfsApp(format!("failure reading directory: {e}")))?; let chunk_size = archive_paths.len() / std::cmp::max(1, parallelism); // a progress bar shared across threads @@ -84,9 +84,7 @@ pub fn batch_process( .total(archive_paths.len()) .animation("fillup") .build() - .map_err(|e| { - ScheduleError::InternalError(format!("failure building progress bar: {e}")) - })?, + .map_err(|e| ScheduleError::Internal(format!("failure building progress bar: {e}")))?, )); let (bundles, errors): (Vec, Vec) = archive_paths @@ -100,13 +98,13 @@ pub fn batch_process( } let path = dir_entry.path(); let bundle_file = path.to_str().ok_or_else(|| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "unable to convert directory entry into string: {dir_entry:?}" )) })?; // let edge_list_id = *start_edge_list_id + edge_list_offset; process_bundle(bundle_file, conf.clone()).map_err(|e| { - ScheduleError::GtfsAppError(format!("while processing {bundle_file}, {e}")) + ScheduleError::GtfsApp(format!("while processing {bundle_file}, {e}")) }) }) .collect_vec() @@ -205,7 +203,7 @@ pub fn process_bundle( .pick_date(&target_date, &trip, gtfs.clone())?; if target_date != picked_date { let route = gtfs.get_route(&trip.route_id).map_err(|_| { - ScheduleError::MalformedGtfsError(format!( + ScheduleError::MalformedGtfs(format!( "trip {} references route id {} that does not exist", trip.trip_id, trip.route_id )) @@ -270,7 +268,7 @@ pub fn write_bundle( let metadata_filename = format!("edges-gtfs-metadata-{edge_list_id}.json"); std::fs::create_dir_all(output_directory).map_err(|e| { let outdir = output_directory.to_str().unwrap_or_default(); - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "unable to create output directory path '{outdir}': {e}" )) })?; @@ -283,11 +281,10 @@ pub fn write_bundle( metadata["fq_route_ids"] = json![fq_route_ids]; let metadata_str = serde_json::to_string_pretty(&metadata).map_err(|e| { - ScheduleError::GtfsAppError(format!("failure writing GTFS Agencies as JSON string: {e}")) - })?; - std::fs::write(output_directory.join(metadata_filename), &metadata_str).map_err(|e| { - ScheduleError::GtfsAppError(format!("failed writing GTFS Agency metadata: {e}")) + ScheduleError::GtfsApp(format!("failure writing GTFS Agencies as JSON string: {e}")) })?; + std::fs::write(output_directory.join(metadata_filename), &metadata_str) + .map_err(|e| ScheduleError::GtfsApp(format!("failed writing GTFS Agency metadata: {e}")))?; let edges_filename = format!("edges-compass-{edge_list_id}.csv.gz"); let schedules_filename = format!("edges-schedules-{edge_list_id}.csv.gz"); let geometries_filename = format!("edges-geometries-enumerated-{edge_list_id}.txt.gz"); @@ -321,7 +318,7 @@ pub fn write_bundle( { if let Some(ref mut writer) = edges_writer { writer.serialize(edge).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "Failed to write to edges file {}: {}", String::from(&edges_filename), e @@ -333,7 +330,7 @@ pub fn write_bundle( for schedule in schedules.iter() { let fq_schedule = FullyQualifiedScheduleRow::new(schedule, edge_list_id); writer.serialize(fq_schedule).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "Failed to write to schedules file {}: {}", String::from(&schedules_filename), e @@ -346,7 +343,7 @@ pub fn write_bundle( writer .serialize(geometry.to_wkt().to_string()) .map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "Failed to write to geometry file {}: {}", String::from(&edges_filename), e @@ -364,6 +361,7 @@ pub fn write_bundle( /// - creates a [GtfsEdge] if one does not yet exist between these vertices. /// - handles presence of src + dst times and constructs the datetimes to write to our schedule row /// - adds this schedule row to our GtfsEdge +#[allow(clippy::too_many_arguments)] fn process_schedule( picked_date: &NaiveDate, src: &StopTime, @@ -387,7 +385,7 @@ fn process_schedule( (Some(result), _) => result, (None, MissingStopLocationPolicy::Fail) => { let msg = format!("{} or {}", src.stop.id, dst.stop.id); - return Err(ScheduleError::MissingStopLocationAndParentError(msg)); + return Err(ScheduleError::MissingStopLocationAndParent(msg)); } (None, MissingStopLocationPolicy::DropStop) => return Ok(None), }; @@ -433,11 +431,11 @@ fn process_schedule( let arr_dt = create_datetime(arr, picked_date)?; Ok((dep_dt, arr_dt)) } - (None, None) => Err(ScheduleError::MissingAllStopTimesError(src.stop.id.clone())), + (None, None) => Err(ScheduleError::MissingAllStopTimes(src.stop.id.clone())), }?; let route = gtfs.routes.get(&trip.route_id).ok_or_else(|| { - ScheduleError::MalformedGtfsError(format!( + ScheduleError::MalformedGtfs(format!( "trip {} has route id {} which is missing from the archive", trip.trip_id, trip.route_id )) @@ -470,7 +468,7 @@ fn create_datetime(gtfs_time: u32, date: &NaiveDate) -> Result Result, ScheduleError> { // Since `stop_locations` is computed from `gtfs.stops`, this should never fail let maybe_src = stop_locations.get(&src.stop.id).ok_or_else(|| { - ScheduleError::MalformedGtfsError(format!( + ScheduleError::MalformedGtfs(format!( "source stop_id '{}' is not associated with a geographic location in either it's stop row or any parent row (see 'parent_station' of GTFS Stops.txt)", src.stop.id )) })?; let maybe_dst = stop_locations.get(&dst.stop.id).ok_or_else(|| { - ScheduleError::MalformedGtfsError(format!( + ScheduleError::MalformedGtfs(format!( "destination stop_id '{}' is not associated with a geographic location in either it's stop row or any parent row (see 'parent_station' of GTFS Stops.txt)", dst.stop.id )) @@ -594,7 +592,7 @@ fn match_closest_graph_id( let nearest_result = spatial_index.nearest_graph_id(&point_f32)?; match nearest_result { NearestSearchResult::NearestVertex(vertex_id) => Ok(vertex_id), - _ => Err(ScheduleError::GtfsAppError(format!( + _ => Err(ScheduleError::GtfsApp(format!( "could not find matching vertex for point {} in spatial index. consider expanding the distance tolerance or allowing for stop filtering.", point.to_wkt() ))), diff --git a/rust/bambam-gtfs/src/schedule/date/date_ops.rs b/rust/bambam-gtfs/src/schedule/date/date_ops.rs index 6fe67e2e..81a6ce17 100644 --- a/rust/bambam-gtfs/src/schedule/date/date_ops.rs +++ b/rust/bambam-gtfs/src/schedule/date/date_ops.rs @@ -79,7 +79,7 @@ pub fn find_in_calendar( Ok(*target) } else { let msg = error_msg_suffix(target, start, end); - Err(ScheduleError::InvalidDataError(format!( + Err(ScheduleError::InvalidData(format!( "no calendar.txt dates match {msg}" ))) } @@ -101,7 +101,7 @@ pub fn confirm_add_exception( "no calendar_dates match target date '{}' with exception_type as 'added'", target.format(APP_DATE_FORMAT), ); - Err(ScheduleError::InvalidDataError(msg)) + Err(ScheduleError::InvalidData(msg)) } } } @@ -154,7 +154,7 @@ pub fn find_nearest_add_exception( target.format(APP_DATE_FORMAT), mwd_str ); - Err(ScheduleError::InvalidDataError(msg)) + Err(ScheduleError::InvalidData(msg)) } } } @@ -179,7 +179,7 @@ pub fn step_date(date: NaiveDate, step: i64) -> Result step, date.format(APP_DATE_FORMAT) ); - ScheduleError::InvalidDataError(msg) + ScheduleError::InvalidData(msg) }) } @@ -228,7 +228,7 @@ mod tests { let result = step_date(date, 1); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("failure adding")); assert!(msg.contains("bounds error")); } else { @@ -243,7 +243,7 @@ mod tests { let result = step_date(date, -1); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("failure subtracting")); assert!(msg.contains("bounds error")); } else { @@ -288,7 +288,7 @@ mod tests { let result = confirm_add_exception(&target, &calendar_dates); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("no calendar_dates match target date")); assert!(msg.contains("06-15-2023")); // MM-DD-YYYY format assert!(msg.contains("exception_type as 'added'")); @@ -478,7 +478,7 @@ mod tests { let result = find_nearest_add_exception(&target, &calendar_dates, 5, false); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("no Added entry in calendar_dates.txt")); assert!(msg.contains("within 5 days")); assert!(msg.contains("06-15-2023")); // MM-DD-YYYY format @@ -507,7 +507,7 @@ mod tests { let result = find_nearest_add_exception(&target, &calendar_dates, 5, true); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("no Added entry in calendar_dates.txt")); assert!(msg.contains("with matching weekday")); } else { @@ -643,7 +643,7 @@ mod tests { let result = find_in_calendar(&target, &calendar); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("no calendar.txt dates match")); assert!(msg.contains("06-05-2023")); assert!(msg.contains("[06-10-2023,06-30-2023]")); @@ -661,7 +661,7 @@ mod tests { let result = find_in_calendar(&target, &calendar); assert!(result.is_err()); - if let Err(ScheduleError::InvalidDataError(msg)) = result { + if let Err(ScheduleError::InvalidData(msg)) = result { assert!(msg.contains("no calendar.txt dates match")); assert!(msg.contains("06-25-2023")); assert!(msg.contains("[06-01-2023,06-20-2023]")); diff --git a/rust/bambam-gtfs/src/schedule/date_mapping_policy.rs b/rust/bambam-gtfs/src/schedule/date_mapping_policy.rs index a51c9de5..791075d8 100644 --- a/rust/bambam-gtfs/src/schedule/date_mapping_policy.rs +++ b/rust/bambam-gtfs/src/schedule/date_mapping_policy.rs @@ -75,7 +75,7 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { match value { DateMappingPolicyConfig::ExactDate(date_str) => { let date = NaiveDate::parse_from_str(date_str, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading date for exact date mapping policy: {e}" )) })?; @@ -87,13 +87,13 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { } => { let start_date = NaiveDate::parse_from_str(start_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_date for exact range mapping policy: {e}" )) })?; let end_date = NaiveDate::parse_from_str(end_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_date for exact range mapping policy: {e}" )) })?; @@ -108,7 +108,7 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { match_weekday, } => { let date = NaiveDate::parse_from_str(date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading date for nearest date mapping policy: {e}" )) })?; @@ -126,13 +126,13 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { } => { let start_date = NaiveDate::parse_from_str(start_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_date for nearest range mapping policy: {e}" )) })?; let end_date = NaiveDate::parse_from_str(end_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_date for nearest range mapping policy: {e}" )) })?; @@ -151,25 +151,25 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { } => { let start_date = NaiveDate::parse_from_str(start_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_date for exact date time range mapping policy: {e}" )) })?; let end_date = NaiveDate::parse_from_str(end_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_date for exact date time range mapping policy: {e}" )) })?; let start_time = NaiveTime::parse_from_str(start_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_time for exact date time range mapping policy: {e}" )) })?; let end_time = NaiveTime::parse_from_str(end_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_time for exact date time range mapping policy: {e}" )) })?; @@ -190,25 +190,25 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { } => { let start_date = NaiveDate::parse_from_str(start_date, APP_DATE_FORMAT) .map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_date for nearest date time range mapping policy: {e}" )) })?; let end_date = NaiveDate::parse_from_str(end_date, APP_DATE_FORMAT) .map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_date for nearest date time range mapping policy: {e}" )) })?; let start_time = NaiveTime::parse_from_str(start_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_time for nearest date time range mapping policy: {e}" )) })?; let end_time = NaiveTime::parse_from_str(end_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_time for nearest date time range mapping policy: {e}" )) })?; @@ -230,25 +230,25 @@ impl TryFrom<&DateMappingPolicyConfig> for DateMappingPolicy { } => { let start_date = NaiveDate::parse_from_str(start_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_date for best case mapping policy: {e}" )) })?; let end_date = NaiveDate::parse_from_str(end_date, APP_DATE_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_date for best case mapping policy: {e}" )) })?; let start_time = NaiveTime::parse_from_str(start_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading start_time for best case mapping policy: {e}" )) })?; let end_time = NaiveTime::parse_from_str(end_time, APP_TIME_FORMAT).map_err(|e| { - ScheduleError::GtfsAppError(format!( + ScheduleError::GtfsApp(format!( "failure reading end_time for best case mapping policy: {e}" )) })?; @@ -347,7 +347,7 @@ impl DateMappingPolicy { format!("While attempting to pick nearest date within {date_tolerance} days matching weekday: {e2}."), format!("While attempting to pick nearest date within {date_tolerance} days without matching weekday: {e3}.") ].join(" "); - Err(ScheduleError::InvalidDataError(msg)) + Err(ScheduleError::InvalidData(msg)) } } } @@ -394,7 +394,7 @@ fn pick_exact_date( match (c_opt, cd_opt) { (None, None) => { let msg = format!("cannot pick date with trip_id '{}' as it does not match calendar or calendar dates", trip.trip_id); - Err(ScheduleError::MalformedGtfsError(msg)) + Err(ScheduleError::MalformedGtfs(msg)) } (Some(c), None) => date_ops::find_in_calendar(target, c), (None, Some(cd)) => date_ops::confirm_add_exception(target, cd), @@ -403,14 +403,14 @@ fn pick_exact_date( if date_ops::confirm_no_delete_exception(target, cd) { Ok(*target) } else { - Err(ScheduleError::InvalidDataError(format!( + Err(ScheduleError::InvalidData(format!( "date {} is valid for calendar.txt but has exception of deleted in calendar_dates.txt", target.format(APP_DATE_FORMAT) ))) } } Err(ce) => date_ops::confirm_add_exception(target, cd) - .map_err(|e| ScheduleError::InvalidDataError(format!("{ce}, {e}"))), + .map_err(|e| ScheduleError::InvalidData(format!("{ce}, {e}"))), }, } } @@ -429,7 +429,7 @@ fn pick_nearest_date( match (c_opt, cd_opt) { (None, None) => { let msg = format!("cannot pick date with trip_id '{}' as it does not match calendar or calendar dates", trip.trip_id); - Err(ScheduleError::MalformedGtfsError(msg)) + Err(ScheduleError::MalformedGtfs(msg)) } (None, Some(cd)) => { date_ops::find_nearest_add_exception(target, cd, date_tolerance, match_weekday) @@ -444,7 +444,7 @@ fn pick_nearest_date( )?; matches.first().cloned().ok_or_else(|| { let msg = date_ops::error_msg_suffix(target, &c.start_date, &c.end_date); - ScheduleError::InvalidDataError(format!( + ScheduleError::InvalidData(format!( "could not find nearest (by {date_tolerance} days) date {msg}" )) }) @@ -485,7 +485,7 @@ fn pick_nearest_date( .cloned(); min_distance_match.ok_or_else(|| { - ScheduleError::InvalidDataError(format!( + ScheduleError::InvalidData(format!( "no match found across calendar + calendar_dates {}", date_ops::error_msg_suffix(target, &c.start_date, &c.end_date) )) diff --git a/rust/bambam-gtfs/src/schedule/date_mapping_policy_config.rs b/rust/bambam-gtfs/src/schedule/date_mapping_policy_config.rs index ee10e79b..d56a9b68 100644 --- a/rust/bambam-gtfs/src/schedule/date_mapping_policy_config.rs +++ b/rust/bambam-gtfs/src/schedule/date_mapping_policy_config.rs @@ -78,8 +78,8 @@ pub enum DateMappingPolicyConfig { impl DateMappingPolicyConfig { /// build a new [`DateMappingPolicy`] configuration from CLI arguments. pub fn new( - start_date: &String, - end_date: &String, + start_date: &str, + end_date: &str, start_time: Option<&String>, end_time: Option<&String>, date_mapping_policy: &DateMappingPolicyType, @@ -89,77 +89,77 @@ impl DateMappingPolicyConfig { use DateMappingPolicyConfig as Config; use DateMappingPolicyType as Type; match date_mapping_policy { - Type::ExactDate => Ok(Config::ExactDate(start_date.clone())), + Type::ExactDate => Ok(Config::ExactDate(start_date.to_string())), Type::ExactRange => Ok(Config::ExactDateRange { - start_date: start_date.clone(), - end_date: end_date.clone(), + start_date: start_date.to_string(), + end_date: end_date.to_string(), }), Type::NearestDate => { - let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsAppError(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; + let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsApp(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; let date_tolerance = date_mapping_date_tolerance.ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "for nearest-date mapping, must specify a date_tolerance in [0, inf)", )) })?; Ok(Self::NearestDate { - date: start_date.clone(), + date: start_date.to_string(), date_tolerance, match_weekday, }) } Type::NearestRange => { - let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsAppError(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; + let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsApp(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; let date_tolerance = date_mapping_date_tolerance.ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "for nearest-date mapping, must specify a date_tolerance in [0, inf)", )) })?; Ok(Self::NearestDateRange { - start_date: start_date.clone(), - end_date: end_date.clone(), + start_date: start_date.to_string(), + end_date: end_date.to_string(), date_tolerance, match_weekday, }) } Type::ExactDateTimeRange => { let start_time = start_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide start_time for exact date time range policy", )) })?; let end_time = end_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide end_time for exact date time range policy", )) })?; Ok(Config::ExactDateTimeRange { - start_date: start_date.clone(), - end_date: end_date.clone(), + start_date: start_date.to_string(), + end_date: end_date.to_string(), start_time, end_time, }) } Type::NearestDateTimeRange => { let start_time = start_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide start_time for nearest date time range policy", )) })?; let end_time = end_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide end_time for nearest date time range policy", )) })?; - let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsAppError(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; + let match_weekday = date_mapping_match_weekday.ok_or_else(|| ScheduleError::GtfsApp(String::from("for nearest-date mapping, must specify 'match_weekday' as 'true' or 'false'")))?; let date_tolerance = date_mapping_date_tolerance.ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "for nearest-date mapping, must specify a date_tolerance in [0, inf)", )) })?; Ok(Self::NearestDateTimeRange { - start_date: start_date.clone(), - end_date: end_date.clone(), + start_date: start_date.to_string(), + end_date: end_date.to_string(), start_time, end_time, date_tolerance, @@ -168,18 +168,18 @@ impl DateMappingPolicyConfig { } Type::BestCase => { let start_time = start_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide start_time for best case policy", )) })?; let end_time = end_time.cloned().ok_or_else(|| { - ScheduleError::GtfsAppError(String::from( + ScheduleError::GtfsApp(String::from( "must provide end_time for best case policy", )) })?; Ok(Self::BestCase { - start_date: start_date.clone(), - end_date: end_date.clone(), + start_date: start_date.to_string(), + end_date: end_date.to_string(), start_time, end_time, date_tolerance: date_mapping_date_tolerance, diff --git a/rust/bambam-gtfs/src/schedule/schedule_error.rs b/rust/bambam-gtfs/src/schedule/schedule_error.rs index 400c8ffc..d3615f61 100644 --- a/rust/bambam-gtfs/src/schedule/schedule_error.rs +++ b/rust/bambam-gtfs/src/schedule/schedule_error.rs @@ -4,41 +4,41 @@ use routee_compass_core::model::map::MapError; #[derive(thiserror::Error, Debug)] pub enum ScheduleError { #[error("Failed to parse gtfs bundle file into `Gtfs` struct: {0}")] - BundleReadError(#[from] gtfs_structures::Error), // { source: gtfs_structures::Error }, + BundleRead(#[from] gtfs_structures::Error), // { source: gtfs_structures::Error }, #[error("failure running bambam_gtfs: {0}")] - GtfsAppError(String), + GtfsApp(String), #[error("Failed to match point with spatial index: {source}")] - SpatialIndexMapError { + SpatialIndexMap { #[from] source: MapError, }, #[error("Spatial index matched an edge instead of a vertex")] - SpatialIndexIncorrectMapError, + SpatialIndexIncorrectMap, #[error("Missing lon,lat data and parent_location for stop: {0}")] - MissingStopLocationAndParentError(String), + MissingStopLocationAndParent(String), #[error("Missing both arrival and departure times: {0}")] - MissingAllStopTimesError(String), + MissingAllStopTimes(String), #[error("At least one of the stops in edge is missing shape distance traveled: {0} or {1}")] - MissingShapeDistanceTraveledError(String, String), + MissingShapeDistanceTraveled(String, String), #[error("Failed to create vertex index: {0}")] - FailedToCreateVertexIndexError(String), + FailedToCreateVertexIndex(String), #[error("Cannot find service in calendar.txt with service_id: {0}")] - InvalidCalendarError(String), + InvalidCalendar(String), #[error("Cannot find service in calendar_dates.txt with service_id: {0}")] - InvalidCalendarDatesError(String), + InvalidCalendarDates(String), #[error("Invalid Edges and schedules keys")] - InvalidResultKeysError, + InvalidResultKeys, #[error("error due to dataset contents: {0}")] - InvalidDataError(String), + InvalidData(String), #[error("GTFS archive is malformed: {0}")] - MalformedGtfsError(String), + MalformedGtfs(String), #[error("Internal Error: {0}")] - InternalError(String), + Internal(String), #[error("errors encountered during batch bundle processing: {0}")] - BatchProcessingError(String), + BatchProcessing(String), } pub fn batch_processing_error(errors: &[ScheduleError]) -> ScheduleError { let concatenated = errors.iter().map(|e| e.to_string()).join("\n "); - ScheduleError::BatchProcessingError(format!("[\n {concatenated}\n]")) + ScheduleError::BatchProcessing(format!("[\n {concatenated}\n]")) } diff --git a/rust/bambam-gtfs/src/schedule/sorted_trip.rs b/rust/bambam-gtfs/src/schedule/sorted_trip.rs index 992ced1e..932c2026 100644 --- a/rust/bambam-gtfs/src/schedule/sorted_trip.rs +++ b/rust/bambam-gtfs/src/schedule/sorted_trip.rs @@ -54,7 +54,7 @@ fn get_ordered_stops(trip: &Trip) -> Result, ScheduleError> { .map(|(_, idx)| { trip.stop_times.get(*idx).cloned().ok_or_else(|| { let msg = format!("expected stop index {idx} not found in trip {}", trip.id); - ScheduleError::MalformedGtfsError(msg) + ScheduleError::MalformedGtfs(msg) }) }) .collect::, _>>() diff --git a/rust/bambam-omf/src/collection/record/transportation_segment.rs b/rust/bambam-omf/src/collection/record/transportation_segment.rs index dcd1d565..6bb4ce54 100644 --- a/rust/bambam-omf/src/collection/record/transportation_segment.rs +++ b/rust/bambam-omf/src/collection/record/transportation_segment.rs @@ -319,6 +319,7 @@ pub enum SegmentRoadFlags { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] #[serde(rename_all = "snake_case")] +#[allow(clippy::enum_variant_names)] pub enum SegmentRailFlags { IsBridge, IsTunnel, diff --git a/rust/bambam-omf/src/graph/segment_split.rs b/rust/bambam-omf/src/graph/segment_split.rs index dbcf2588..486520ab 100644 --- a/rust/bambam-omf/src/graph/segment_split.rs +++ b/rust/bambam-omf/src/graph/segment_split.rs @@ -34,7 +34,6 @@ impl SegmentSplit { } /// identifies any locations where additional coordinates are needed. - /// when creating any missing connectors, call [ConnectorInSegment::new_without_connector_id] /// which generates a new connector_id based on the segment_id and linear referencing position. pub fn missing_connectors(&self) -> Vec { @@ -217,13 +216,13 @@ impl SegmentSplit { // retain speed limits with no heading or with a matching heading let speed_limits_with_heading = speed_limits .iter() - .filter_map(|s| match s.when.as_ref() { + .filter(|s| match s.when.as_ref() { Some(access) => match access.heading.as_ref() { - None => Some(s), - Some(h) if h == heading => Some(s), - _ => None, + None => true, + Some(h) if h == heading => true, + _ => false, }, - None => None, + None => false, }) .collect_vec(); @@ -290,12 +289,12 @@ impl SegmentSplit { // This ignores errors in `check_open_intersection` coming from invalid between values let opt_first_matching_sublcass = segment.subclass_rules.as_ref().and_then(|rules| { - rules.iter().find_map(|rule| { - match rule.check_open_intersection(start, end) { - Ok(true) => Some(rule), - _ => None, - } - }) + rules + .iter() + .find(|rule| match rule.check_open_intersection(start, end) { + Ok(true) => true, + _ => false, + }) }); // Get value from inside diff --git a/rust/bambam-osm/src/algorithm/buffer.rs b/rust/bambam-osm/src/algorithm/buffer.rs index 4fb5b18a..c5f13b0f 100644 --- a/rust/bambam-osm/src/algorithm/buffer.rs +++ b/rust/bambam-osm/src/algorithm/buffer.rs @@ -53,7 +53,7 @@ impl Buffer for Point { /// * `point` - a WGS84 point with x,y ordering /// * `radius` - buffer size, in meters /// * `resolution` - number of evenly-space points to place along the new buffer cirumference. -/// more points make a better approximation of a circle. +/// more points make a better approximation of a circle. /// /// * Returns /// diff --git a/rust/bambam-osm/src/algorithm/connected_components.rs b/rust/bambam-osm/src/algorithm/connected_components.rs index dd1762bb..b0d26899 100644 --- a/rust/bambam-osm/src/algorithm/connected_components.rs +++ b/rust/bambam-osm/src/algorithm/connected_components.rs @@ -47,7 +47,7 @@ fn add_undirected_edge(src: &OsmNodeId, dst: &OsmNodeId, g: &mut UndirectedAdjac /// * `fwd` - forward traversal segments, the "out-edges" of the nodes /// * `rev` - reverse traversal segments, the "in-edges" of the nodes /// * `nodes` - the graph nodes included to find components. -/// this can either be the complete set or a subset. +/// this can either be the complete set or a subset. /// /// # Result /// diff --git a/rust/bambam-osm/src/algorithm/consolidation/consolidation_ops.rs b/rust/bambam-osm/src/algorithm/consolidation/consolidation_ops.rs index a164d9f0..bfba8cdc 100644 --- a/rust/bambam-osm/src/algorithm/consolidation/consolidation_ops.rs +++ b/rust/bambam-osm/src/algorithm/consolidation/consolidation_ops.rs @@ -27,9 +27,9 @@ use std::sync::Mutex; /// /// * `graph` - the original graph data from the .pbf file /// * `tolerance` - edge-connected endpoints within this distance threshold are merged -/// into a new graph vertex by their centroid +/// into a new graph vertex by their centroid /// * `ignore_osm_parsing_errors` - if true, do not fail if a maxspeed or other attribute is not -/// valid wrt the OpenStreetMaps documentation +/// valid wrt the OpenStreetMaps documentation pub fn consolidate_graph( graph: &mut OsmGraph, tolerance: uom::si::f64::Length, @@ -339,7 +339,7 @@ fn consolidate_nodes(node_ids: Vec, graph: &mut OsmGraph) -> Result<( /// modifies the way.nodes collection so it does not include any removed nodes and /// the new consolidated node is inserted in the correct place depending on the way direction. fn update_way_nodes( - ways: &mut Vec, + ways: &mut [OsmWayData], new_node_id: &OsmNodeId, remove_nodes: &HashSet, dir: &AdjacencyDirection, @@ -374,8 +374,7 @@ fn update_way_nodes( /// geometry indices. /// /// # Arguments -/// * `geometry_indices` - indices into the spatial intersection vector that -/// will be considered for clustering +/// * `geometry_indices` - indices into the spatial intersection vector that will be considered for clustering /// * `simplified` - the simplified graph /// * `endpoint_index_osmid_mapping` - maps indices to Node OSMIDs /// diff --git a/rust/bambam-osm/src/algorithm/simplification/simplify_ops.rs b/rust/bambam-osm/src/algorithm/simplification/simplify_ops.rs index 52296a83..74576487 100644 --- a/rust/bambam-osm/src/algorithm/simplification/simplify_ops.rs +++ b/rust/bambam-osm/src/algorithm/simplification/simplify_ops.rs @@ -485,17 +485,13 @@ fn build_path( /// /// 1) It is its own neighbor (ie, it self-loops). /// -/// 2) Or, it has no incoming edges or no outgoing edges (ie, all its incident -/// edges are inbound or all its incident edges are outbound). +/// 2) Or, it has no incoming edges or no outgoing edges (ie, all its incident edges are inbound or all its incident edges are outbound). /// /// 3) Or, it does not have exactly two neighbors and degree of 2 or 4. /// -/// 4) Or, if `node_attrs_include` is not None and it has one or more of the -/// attributes in `node_attrs_include`. +/// 4) Or, if `node_attrs_include` is not None and it has one or more of the attributes in `node_attrs_include`. /// -/// 5) Or, if `edge_attrs_differ` is not None and its incident edges have -/// different values than each other for any of the edge attributes in -/// `edge_attrs_differ`. +/// 5) Or, if `edge_attrs_differ` is not None and its incident edges have different values than each other for any of the edge attributes in `edge_attrs_differ`. fn node_is_endpoint(id: &OsmNodeId, graph: &OsmGraph) -> Result { // neighbors is the set of unique nodes connected to this node let succ = graph.get_out_neighbors(id).unwrap_or_default();