Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions src/pymatching/sparse_blossom/driver/user_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
#include "pymatching/rand/rand_gen.h"
#include "pymatching/sparse_blossom/driver/implied_weights.h"

namespace {

double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

} // namespace


double pm::to_weight_for_correlations(double probability) {
return std::log((1 - probability) / probability);
}

double pm::merge_weights(double a, double b) {
auto sgn = std::copysign(1, a) * std::copysign(1, b);
auto signed_min = sgn * std::min(std::abs(a), std::abs(b));
Expand Down Expand Up @@ -342,6 +355,15 @@ void pm::UserGraph::handle_dem_instruction(
}
}

void pm::UserGraph::handle_dem_instruction_include_correlations(
double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables) {
if (detectors.size() == 2) {
add_or_merge_edge(detectors[0], detectors[1], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
} else if (detectors.size() == 1) {
add_or_merge_boundary_edge(detectors[0], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
}
}

void pm::UserGraph::get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes) {
auto& mwpm = get_mwpm_with_search_graph();
bool src_is_boundary = is_boundary_node(src);
Expand Down Expand Up @@ -444,18 +466,6 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin
}
}

namespace {

double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

double to_weight(double probability) {
return std::log((1 - probability) / probability);
}

} // namespace

void pm::add_decomposed_error_to_joint_probabilities(
DecomposedDemError& error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
Expand Down Expand Up @@ -490,7 +500,7 @@ pm::UserGraph pm::detector_error_model_to_user_graph(
pm::iter_dem_instructions_include_correlations(
detector_error_model,
[&](double p, const std::vector<size_t>& detectors, std::vector<size_t>& observables) {
user_graph.handle_dem_instruction(p, detectors, observables);
user_graph.handle_dem_instruction_include_correlations(p, detectors, observables);
},
joint_probabilites);

Expand Down Expand Up @@ -526,7 +536,7 @@ void pm::UserGraph::populate_implied_edge_weights(
// minimum of 0.5 as an implied probability for an edge to be reweighted.
double implied_probability_for_other_edge =
std::min(0.5, affected_edge_and_probability.second / marginal_probability);
double w = to_weight(implied_probability_for_other_edge);
double w = pm::to_weight_for_correlations(implied_probability_for_other_edge);
ImpliedWeightUnconverted implied{affected_edge.first, affected_edge.second, w};
edge.implied_weights_for_other_edges.push_back(implied);
}
Expand Down
54 changes: 26 additions & 28 deletions src/pymatching/sparse_blossom/driver/user_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class UserGraph {
Mwpm& get_mwpm();
Mwpm& get_mwpm_with_search_graph();
void handle_dem_instruction(double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
void handle_dem_instruction_include_correlations(
double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
void get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes);
void populate_implied_edge_weights(
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites);
Expand All @@ -131,6 +133,8 @@ class UserGraph {
bool _all_edges_have_error_probabilities;
};

double to_weight_for_correlations(double probability);

template <typename EdgeCallable, typename BoundaryEdgeCallable>
inline double UserGraph::iter_discretized_edges(
pm::weight_int num_distinct_weights,
Expand Down Expand Up @@ -263,7 +267,8 @@ template <typename Handler>
void iter_dem_instructions_include_correlations(
const stim::DetectorErrorModel& detector_error_model,
const Handler& handle_dem_error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites,
bool include_decomposed_error_components_in_edge_weights = true) {
detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) {
double p = instruction.arg_data[0];
pm::DecomposedDemError decomposed_err;
Expand Down Expand Up @@ -311,35 +316,23 @@ void iter_dem_instructions_include_correlations(
component->observable_indices.push_back(target.val());
} else if (target.is_separator()) {
instruction_contains_separator = true;
// If the previous error in the decomposition had 3 or more detectors, we throw an exception.
if (num_component_detectors > 2) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with a hyperedge component (3 or more detectors). "
"This is not supported.");
} else if (num_component_detectors == 0) {
// We cannot have num_component_detectors > 2 at this point, or we would have already thrown an
// exception
if (num_component_detectors == 0) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
"This is not supported.");
} else if (num_component_detectors > 0) {
// If the previous error in the decomposition had 1 or 2 detectors, we handle it
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
// The previous error in the decomposition must have 1 or 2 detectors
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
}

if (num_component_detectors > 2) {
// Undecomposed hyperedges are not supported
throw std::invalid_argument(
"Encountered an undecomposed error instruction with 3 or mode detectors. "
"This is not supported when using `enable_correlations=True`. "
"Did you forget to set `decompose_errors=True` when "
"converting the stim circuit to a detector error model?");
} else if (num_component_detectors == 0) {
if (num_component_detectors == 0) {
if (instruction_contains_separator) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
Expand All @@ -348,12 +341,17 @@ void iter_dem_instructions_include_correlations(
// Ignore errors that are undetectable, provided they are not a component of a decomposed error
return;
}
}

} else if (num_component_detectors > 0) {
if (component->node2 == SIZE_MAX) {
handle_dem_error(p, {component->node1}, component->observable_indices);
} else {
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
// If include_decomposed_error_components_in_edge_weights, then only add the edge into the graph if
// it is not a component in a decomposed error with more than one component
if (include_decomposed_error_components_in_edge_weights || decomposed_err.components.size() == 1) {
for (pm::UserEdge& component : decomposed_err.components) {
if (component.node2 == SIZE_MAX) {
handle_dem_error(p, {component.node1}, component.observable_indices);
} else {
handle_dem_error(p, {component.node1, component.node2}, component.observable_indices);
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions src/pymatching/sparse_blossom/driver/user_graph.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) {
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 4);

std::vector<HandledError> expected = {
{0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}};
EXPECT_EQ(handler.handled_errors, expected);
Expand Down Expand Up @@ -485,14 +485,6 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {

graph.populate_implied_edge_weights(joint_probabilities);

auto to_weight = [](double p) {
if (p == 1.0)
return -std::numeric_limits<double>::infinity();
if (p == 0.0)
return std::numeric_limits<double>::infinity();
return std::log((1 - p) / p);
};

auto it_01 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
return edge.node1 == 0 && edge.node2 == 1;
});
Expand All @@ -503,7 +495,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
ASSERT_EQ(implied_01.node2, 3);

double p_01 = 0.1 / 0.26;
double w_01 = to_weight(p_01);
double w_01 = pm::to_weight_for_correlations(p_01);
ASSERT_EQ(implied_01.implied_weight, w_01);

auto it_23 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
Expand All @@ -515,7 +507,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
const auto& implied_23 = it_23->implied_weights_for_other_edges[0];
ASSERT_EQ(implied_23.node1, 0);
ASSERT_EQ(implied_23.node2, 1);
ASSERT_EQ(implied_23.implied_weight, 0);
ASSERT_NEAR(implied_23.implied_weight, 0.0, 0.00001);
}

TEST(UserGraph, ConvertImpliedWeights) {
Expand Down
Loading