diff --git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc
index 19734730a..313343978 100644
--- a/src/pymatching/sparse_blossom/driver/user_graph.cc
+++ b/src/pymatching/sparse_blossom/driver/user_graph.cc
@@ -17,6 +17,19 @@
 #include "pymatching/rand/rand_gen.h"
 #include "pymatching/sparse_blossom/driver/implied_weights.h"
 
+namespace {
+
+double bernoulli_xor(double p1, double p2) {
+    return p1 * (1 - p2) + p2 * (1 - p1);
+}
+
+}  // namespace
+
+
+double pm::to_weight_for_correlations(double probability) {
+    return std::log((1 - probability) / probability);
+}
+
 double pm::merge_weights(double a, double b) {
     auto sgn = std::copysign(1, a) * std::copysign(1, b);
     auto signed_min = sgn * std::min(std::abs(a), std::abs(b));
@@ -342,6 +355,15 @@ void pm::UserGraph::handle_dem_instruction(
     }
 }
 
+void pm::UserGraph::handle_dem_instruction_include_correlations(
+    double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables) {
+    if (detectors.size() == 2) {
+        add_or_merge_edge(detectors[0], detectors[1], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
+    } else if (detectors.size() == 1) {
+        add_or_merge_boundary_edge(detectors[0], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
+    }
+}
+
 void pm::UserGraph::get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes) {
     auto& mwpm = get_mwpm_with_search_graph();
     bool src_is_boundary = is_boundary_node(src);
@@ -444,18 +466,6 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin
     }
 }
 
-namespace {
-
-double bernoulli_xor(double p1, double p2) {
-    return p1 * (1 - p2) + p2 * (1 - p1);
-}
-
-double to_weight(double probability) {
-    return std::log((1 - probability) / probability);
-}
-
-}  // namespace
-
 void pm::add_decomposed_error_to_joint_probabilities(
     DecomposedDemError& error,
     std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
@@ -490,7 +500,7 @@ pm::UserGraph pm::detector_error_model_to_user_graph(
     pm::iter_dem_instructions_include_correlations(
         detector_error_model,
         [&](double p, const std::vector<size_t>& detectors, std::vector<size_t>& observables) {
-            user_graph.handle_dem_instruction(p, detectors, observables);
+            user_graph.handle_dem_instruction_include_correlations(p, detectors, observables);
         },
         joint_probabilites);
 
@@ -526,7 +536,7 @@ void pm::UserGraph::populate_implied_edge_weights(
             // minimum of 0.5 as an implied probability for an edge to be reweighted.
             double implied_probability_for_other_edge =
                 std::min(0.5, affected_edge_and_probability.second / marginal_probability);
-            double w = to_weight(implied_probability_for_other_edge);
+            double w = pm::to_weight_for_correlations(implied_probability_for_other_edge);
             ImpliedWeightUnconverted implied{affected_edge.first, affected_edge.second, w};
             edge.implied_weights_for_other_edges.push_back(implied);
         }
diff --git a/src/pymatching/sparse_blossom/driver/user_graph.h b/src/pymatching/sparse_blossom/driver/user_graph.h
index ab945cb4b..a8c6ff8ff 100644
--- a/src/pymatching/sparse_blossom/driver/user_graph.h
+++ b/src/pymatching/sparse_blossom/driver/user_graph.h
@@ -120,6 +120,8 @@ class UserGraph {
     Mwpm& get_mwpm();
     Mwpm& get_mwpm_with_search_graph();
     void handle_dem_instruction(double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
+    void handle_dem_instruction_include_correlations(
+        double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
     void get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes);
     void populate_implied_edge_weights(
         std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites);
@@ -131,6 +133,8 @@ class UserGraph {
     bool _all_edges_have_error_probabilities;
 };
 
+double to_weight_for_correlations(double probability);
+
 template <typename EdgeCallable, typename BoundaryEdgeCallable>
 inline double UserGraph::iter_discretized_edges(
     pm::weight_int num_distinct_weights,
@@ -263,7 +267,8 @@ template <typename Handler>
 void iter_dem_instructions_include_correlations(
     const stim::DetectorErrorModel& detector_error_model,
     const Handler& handle_dem_error,
-    std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
+    std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites,
+    bool include_decomposed_error_components_in_edge_weights = true) {
     detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) {
         double p = instruction.arg_data[0];
         pm::DecomposedDemError decomposed_err;
@@ -311,35 +316,23 @@ void iter_dem_instructions_include_correlations(
                 component->observable_indices.push_back(target.val());
             } else if (target.is_separator()) {
                 instruction_contains_separator = true;
-                // If the previous error in the decomposition had 3 or more detectors, we throw an exception.
-                if (num_component_detectors > 2) {
-                    throw std::invalid_argument(
-                        "Encountered a decomposed error instruction with a hyperedge component (3 or more detectors). "
-                        "This is not supported.");
-                } else if (num_component_detectors == 0) {
+                // We cannot have num_component_detectors > 2 at this point, or we would have already thrown an
+                // exception
+                if (num_component_detectors == 0) {
                     throw std::invalid_argument(
                         "Encountered a decomposed error instruction with an undetectable component (0 detectors). "
                         "This is not supported.");
-                } else if (num_component_detectors > 0) {
-                    // If the previous error in the decomposition had 1 or 2 detectors, we handle it
-                    handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
-                    decomposed_err.components.push_back({});
-                    component = &decomposed_err.components.back();
-                    component->node1 = SIZE_MAX;
-                    component->node2 = SIZE_MAX;
-                    num_component_detectors = 0;
                 }
+                // The previous error in the decomposition must have 1 or 2 detectors
+                decomposed_err.components.push_back({});
+                component = &decomposed_err.components.back();
+                component->node1 = SIZE_MAX;
+                component->node2 = SIZE_MAX;
+                num_component_detectors = 0;
             }
         }
 
-        if (num_component_detectors > 2) {
-            // Undecomposed hyperedges are not supported
-            throw std::invalid_argument(
-                "Encountered an undecomposed error instruction with 3 or mode detectors. "
-                "This is not supported when using `enable_correlations=True`. "
-                "Did you forget to set `decompose_errors=True` when "
-                "converting the stim circuit to a detector error model?");
-        } else if (num_component_detectors == 0) {
+        if (num_component_detectors == 0) {
            if (instruction_contains_separator) {
                 throw std::invalid_argument(
                     "Encountered a decomposed error instruction with an undetectable component (0 detectors). "
@@ -348,12 +341,17 @@ void iter_dem_instructions_include_correlations(
                 // Ignore errors that are undetectable, provided they are not a component of a decomposed error
                 return;
             }
+        }
 
-        } else if (num_component_detectors > 0) {
-            if (component->node2 == SIZE_MAX) {
-                handle_dem_error(p, {component->node1}, component->observable_indices);
-            } else {
-                handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
+        // If include_decomposed_error_components_in_edge_weights, then only add the edge into the graph if
+        // it is not a component in a decomposed error with more than one component
+        if (include_decomposed_error_components_in_edge_weights || decomposed_err.components.size() == 1) {
+            for (pm::UserEdge& component : decomposed_err.components) {
+                if (component.node2 == SIZE_MAX) {
+                    handle_dem_error(p, {component.node1}, component.observable_indices);
+                } else {
+                    handle_dem_error(p, {component.node1, component.node2}, component.observable_indices);
+                }
             }
         }
 
diff --git a/src/pymatching/sparse_blossom/driver/user_graph.test.cc b/src/pymatching/sparse_blossom/driver/user_graph.test.cc
index cd08bf18e..5d1e08bfe 100644
--- a/src/pymatching/sparse_blossom/driver/user_graph.test.cc
+++ b/src/pymatching/sparse_blossom/driver/user_graph.test.cc
@@ -384,7 +384,7 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) {
     pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
 
     ASSERT_EQ(handler.handled_errors.size(), 4);
-    
+
     std::vector expected = {
         {0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}};
     EXPECT_EQ(handler.handled_errors, expected);
@@ -485,14 +485,6 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
 
     graph.populate_implied_edge_weights(joint_probabilities);
 
-    auto to_weight = [](double p) {
-        if (p == 1.0)
-            return -std::numeric_limits<double>::infinity();
-        if (p == 0.0)
-            return std::numeric_limits<double>::infinity();
-        return std::log((1 - p) / p);
-    };
-
     auto it_01 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
         return edge.node1 == 0 && edge.node2 == 1;
     });
@@ -503,7 +495,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
     ASSERT_EQ(implied_01.node2, 3);
 
     double p_01 = 0.1 / 0.26;
-    double w_01 = to_weight(p_01);
+    double w_01 = pm::to_weight_for_correlations(p_01);
     ASSERT_EQ(implied_01.implied_weight, w_01);
 
     auto it_23 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
@@ -515,7 +507,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
     const auto& implied_23 = it_23->implied_weights_for_other_edges[0];
     ASSERT_EQ(implied_23.node1, 0);
     ASSERT_EQ(implied_23.node2, 1);
-    ASSERT_EQ(implied_23.implied_weight, 0);
+    ASSERT_NEAR(implied_23.implied_weight, 0.0, 0.00001);
 }
 
 TEST(UserGraph, ConvertImpliedWeights) {
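Note on the two helpers this patch promotes to shared use: to_weight_for_correlations(p) is the log-likelihood weight log((1 - p) / p), and bernoulli_xor(p1, p2) = p1 * (1 - p2) + p2 * (1 - p1) is the probability that exactly one of two independent error mechanisms fires. The snippet below is a standalone sketch that only mirrors those two one-line formulas from the diff; it is compiled outside the pymatching build and is not the library API.

// Standalone illustration (not part of the patch): mirrors the two formulas above.
#include <cmath>
#include <cstdio>

// Log-likelihood weight of an error mechanism with probability p.
static double to_weight_for_correlations(double probability) {
    return std::log((1 - probability) / probability);
}

// Probability that exactly one of two independent mechanisms fires.
static double bernoulli_xor(double p1, double p2) {
    return p1 * (1 - p2) + p2 * (1 - p1);
}

int main() {
    std::printf("w(0.5)  = %f\n", to_weight_for_correlations(0.5));   // 0.0: the decoder is indifferent to this edge
    std::printf("w(0.01) = %f\n", to_weight_for_correlations(0.01));  // ~4.595: rarer errors cost more
    std::printf("xor(0.1, 0.1) = %f\n", bernoulli_xor(0.1, 0.1));     // 0.18
    return 0;
}

In particular, log((1 - 0.5) / 0.5) = 0, so an implied probability capped at 0.5 by the std::min(0.5, ...) in populate_implied_edge_weights maps to an implied weight of 0, which is consistent with the test now checking ASSERT_NEAR(implied_23.implied_weight, 0.0, 0.00001) instead of an exact ASSERT_EQ on a computed double.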