diff --git a/src/dsf/bindings.cpp b/src/dsf/bindings.cpp index b0a8fdcd2..e7848d645 100644 --- a/src/dsf/bindings.cpp +++ b/src/dsf/bindings.cpp @@ -208,6 +208,11 @@ PYBIND11_MODULE(dsf_cpp, m) { &dsf::mobility::RoadNetwork::importTrafficLights, pybind11::arg("fileName"), dsf::g_docstrings.at("dsf::mobility::RoadNetwork::importTrafficLights").c_str()) + .def("setTransitionProbabilities", + &dsf::mobility::RoadNetwork::setTransitionProbabilities, + pybind11::arg("transitionProbabilities"), + dsf::g_docstrings.at("dsf::mobility::RoadNetwork::setTransitionProbabilities") + .c_str()) .def( "makeRoundabout", [](dsf::mobility::RoadNetwork& self, dsf::Id id) -> void { diff --git a/src/dsf/mobility/RoadDynamics.hpp b/src/dsf/mobility/RoadDynamics.hpp index 4ee730ad1..81eca8145 100644 --- a/src/dsf/mobility/RoadDynamics.hpp +++ b/src/dsf/mobility/RoadDynamics.hpp @@ -496,12 +496,14 @@ namespace dsf::mobility { // Get current street information std::optional previousNodeId = std::nullopt; std::set forbiddenTurns; + std::unordered_map baseTransitionProbabilities; double speedCurrent = 1.0; if (pAgent->streetId().has_value()) { auto const& pStreetCurrent{this->graph().edge(pAgent->streetId().value())}; previousNodeId = pStreetCurrent->source(); forbiddenTurns = pStreetCurrent->forbiddenTurns(); speedCurrent = pStreetCurrent->maxSpeed(); + baseTransitionProbabilities = pStreetCurrent->transitionProbabilities(); } // Get path targets for non-random agents @@ -551,6 +553,9 @@ namespace dsf::mobility { // Calculate base probability auto const speedNext{pStreetOut->maxSpeed()}; double probability = speedCurrent * speedNext; + if (!baseTransitionProbabilities.empty()) { + probability *= baseTransitionProbabilities.at(outEdgeId); // edge-specific prob + } // Apply error probability for non-random agents if (this->m_errorProbability.has_value() && !pathTargets.empty()) { diff --git a/src/dsf/mobility/RoadNetwork.cpp b/src/dsf/mobility/RoadNetwork.cpp index dce56928e..ce297779d 100644 --- a/src/dsf/mobility/RoadNetwork.cpp +++ b/src/dsf/mobility/RoadNetwork.cpp @@ -264,17 +264,17 @@ namespace dsf::mobility { } } // Check for transition probabilities - // if (!edge_properties.at_key("transition_probabilities").error() && - // edge_properties["transition_probabilities"].has_value()) { - // auto const& tp = edge_properties["transition_probabilities"]; - // std::unordered_map transitionProbabilities; - // for (auto const& [key, value] : tp.get_object()) { - // auto const targetStreetId = static_cast(std::stoull(std::string(key))); - // auto const probability = static_cast(value.get_double()); - // transitionProbabilities.emplace(targetStreetId, probability); - // } - // edge(edge_id)->setTransitionProbabilities(transitionProbabilities); - // } + if (!edge_properties.at_key("transition_probabilities").error() && + edge_properties["transition_probabilities"].has_value()) { + auto const& tp = edge_properties["transition_probabilities"]; + std::unordered_map transitionProbabilities; + for (auto const& [key, value] : tp.get_object()) { + auto const targetStreetId = static_cast(std::stoull(std::string(key))); + auto const probability = static_cast(value.get_double()); + transitionProbabilities.emplace(targetStreetId, probability); + } + edge(edge_id)->setTransitionProbabilities(transitionProbabilities); + } } this->m_nodes.rehash(0); this->m_edges.rehash(0); @@ -792,6 +792,23 @@ namespace dsf::mobility { } } + void RoadNetwork::setTransitionProbabilities( + std::unordered_map> const& + transitionProbabilities) { + std::for_each(DSF_EXECUTION m_edges.cbegin(), + m_edges.cend(), + [&transitionProbabilities](auto const& pair) { + auto const& streetId = pair.first; + auto const& pStreet = pair.second; + auto const it = transitionProbabilities.find(streetId); + if (it != transitionProbabilities.end()) { + pStreet->setTransitionProbabilities(it->second); + } else { + pStreet->setTransitionProbabilities({}); + } + }); + } + TrafficLight& RoadNetwork::makeTrafficLight(Id const nodeId, Delay const cycleTime, Delay const counter) { diff --git a/src/dsf/mobility/RoadNetwork.hpp b/src/dsf/mobility/RoadNetwork.hpp index d23af7d19..0b470a42b 100644 --- a/src/dsf/mobility/RoadNetwork.hpp +++ b/src/dsf/mobility/RoadNetwork.hpp @@ -139,6 +139,10 @@ namespace dsf::mobility { /// and the speed limit, if such data is available in the file. void importTrafficLights(const std::string& fileName); + void setTransitionProbabilities( + std::unordered_map> const& + transitionProbabilities); + template requires is_node_v> && (is_node_v> && ...) diff --git a/src/dsf/mobility/Street.cpp b/src/dsf/mobility/Street.cpp index 5f5dd1d83..203e5b0cc 100644 --- a/src/dsf/mobility/Street.cpp +++ b/src/dsf/mobility/Street.cpp @@ -70,10 +70,28 @@ namespace dsf::mobility { strLaneMapping += std::format("{} - ", directionToString[static_cast(item)]); }); - spdlog::debug("New lane mapping for street {} -> {} is: {}", - m_nodePair.first, - m_nodePair.second, - strLaneMapping); + spdlog::debug("New lane mapping for {} is: {}", *this, strLaneMapping); + } + void Street::setTransitionProbabilities( + std::unordered_map const& transitionProbabilities) noexcept { + // Ensure normalization + if (transitionProbabilities.empty()) { + m_transitionProbabilities.clear(); + return; + } + double sumProbabilities{0.}; + for (auto const& [_, probability] : transitionProbabilities) { + sumProbabilities += probability; + } + if (std::abs(sumProbabilities - 1.) > 1e-6) { + auto tp = transitionProbabilities; + for (auto& [_, probability] : tp) { + probability /= sumProbabilities; + } + m_transitionProbabilities = tp; + return; + } + m_transitionProbabilities = transitionProbabilities; } void Street::setQueue(dsf::queue> queue, size_t index) { assert(index < m_exitQueues.size()); diff --git a/src/dsf/mobility/Street.hpp b/src/dsf/mobility/Street.hpp index c4dce920e..635ee7ac5 100644 --- a/src/dsf/mobility/Street.hpp +++ b/src/dsf/mobility/Street.hpp @@ -48,7 +48,7 @@ namespace dsf::mobility { AgentComparator> m_movingAgents; std::vector m_laneMapping; - // std::unordered_map m_transitionProbabilities; + std::unordered_map m_transitionProbabilities; std::optional m_counter; public: @@ -84,12 +84,10 @@ namespace dsf::mobility { /// @param meanVehicleLength The mean vehicle length /// @throw std::invalid_argument If the mean vehicle length is negative static void setMeanVehicleLength(double meanVehicleLength); - // /// @brief Set the street's transition probabilities - // /// @param transitionProbabilities The street's transition probabilities - // inline void setTransitionProbabilities( - // std::unordered_map const& transitionProbabilities) { - // m_transitionProbabilities = transitionProbabilities; - // }; + /// @brief Set the street's transition probabilities + /// @param transitionProbabilities The street's transition probabilities + void setTransitionProbabilities( + std::unordered_map const& transitionProbabilities) noexcept; /// @brief Enable a coil (dsf::Counter sensor) on the street /// @param name The name of the counter (default is "Coil_") void enableCounter(std::string name = std::string()); @@ -117,10 +115,11 @@ namespace dsf::mobility { /// @brief Check if the street is full /// @return bool, True if the street is full, false otherwise inline bool isFull() const final { return this->nAgents() == this->m_capacity; } - - // inline auto const& transitionProbabilities() const { - // return m_transitionProbabilities; - // } + /// @brief Get the street's transition probabilities + /// @return std::unordered_map The street's transition probabilities + inline auto const& transitionProbabilities() const { + return m_transitionProbabilities; + } /// @brief Get the name of the counter /// @return std::string The name of the counter inline auto counterName() const { diff --git a/test/data/test_transition_probs.geojson b/test/data/test_transition_probs.geojson new file mode 100644 index 000000000..03e2127e5 --- /dev/null +++ b/test/data/test_transition_probs.geojson @@ -0,0 +1,88 @@ +{ + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [11.8810, 44.2230], + [11.8820, 44.2240] + ] + }, + "properties": { + "id": 1, + "source": 0, + "target": 1, + "length": 100.5, + "maxspeed": 50, + "nlanes": 2, + "name": "Test Street A", + "transition_probabilities": { + "2": 0.6, + "3": 0.4 + } + } + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [11.8820, 44.2240], + [11.8830, 44.2250] + ] + }, + "properties": { + "id": 2, + "source": 1, + "target": 2, + "length": 80.0, + "maxspeed": "30", + "nlanes": 1, + "name": "Test Street B" + } + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [11.8820, 44.2240], + [11.8815, 44.2260] + ] + }, + "properties": { + "id": 3, + "source": 1, + "target": 3, + "length": 120.0, + "maxspeed": 40, + "nlanes": 1, + "name": "Test Street C", + "transition_probabilities": { + "4": 1.0 + } + } + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [11.8815, 44.2260], + [11.8810, 44.2270] + ] + }, + "properties": { + "id": 4, + "source": 3, + "target": 4, + "length": 90.0, + "maxspeed": 50, + "nlanes": 2, + "name": "Test Street D" + } + } + ] +} diff --git a/test/mobility/Test_graph.cpp b/test/mobility/Test_graph.cpp index 188fe3929..2bf8c3fbd 100644 --- a/test/mobility/Test_graph.cpp +++ b/test/mobility/Test_graph.cpp @@ -899,4 +899,60 @@ TEST_CASE("ShortestPath") { CHECK_EQ(path.size(), 1); CHECK_EQ(path[0], 0); } + + SUBCASE("Import GeoJSON with Transition Probabilities") { + GIVEN("A GeoJSON file with transition probabilities") { + RoadNetwork graph; + + WHEN("We import the GeoJSON file") { + graph.importEdges((DATA_FOLDER / "test_transition_probs.geojson").string()); + + THEN("The graph is constructed correctly") { + CHECK_EQ(graph.nEdges(), 4); + CHECK_EQ(graph.nNodes(), 5); + } + + THEN("Transition probabilities are correctly imported for street 1") { + auto const& street1 = graph.edge(1); + auto const& transProbs1 = street1->transitionProbabilities(); + + CHECK_EQ(transProbs1.size(), 2); + CHECK(transProbs1.contains(2)); + CHECK(transProbs1.contains(3)); + CHECK_EQ(transProbs1.at(2), doctest::Approx(0.6)); + CHECK_EQ(transProbs1.at(3), doctest::Approx(0.4)); + } + + THEN("Transition probabilities are correctly imported for street 3") { + auto const& street3 = graph.edge(3); + auto const& transProbs3 = street3->transitionProbabilities(); + + CHECK_EQ(transProbs3.size(), 1); + CHECK(transProbs3.contains(4)); + CHECK_EQ(transProbs3.at(4), doctest::Approx(1.0)); + } + + THEN("Streets without transition probabilities have empty maps") { + auto const& street2 = graph.edge(2); + auto const& transProbs2 = street2->transitionProbabilities(); + CHECK_EQ(transProbs2.size(), 0); + + auto const& street4 = graph.edge(4); + auto const& transProbs4 = street4->transitionProbabilities(); + CHECK_EQ(transProbs4.size(), 0); + } + + THEN("Other street properties are imported correctly") { + auto const& street1 = graph.edge(1); + CHECK_EQ(street1->id(), 1); + CHECK_EQ(street1->source(), 0); + CHECK_EQ(street1->target(), 1); + CHECK_EQ(street1->length(), doctest::Approx(100.5)); + CHECK_EQ(street1->maxSpeed(), doctest::Approx(50.0 / 3.6)); + CHECK_EQ(street1->nLanes(), 2); + CHECK_EQ(street1->name(), "Test Street A"); + } + } + } + } }