Skip to content

Commit 6c33cdb

Browse files
Handle boundary edge components and refactor iter_dem_instructions_include_correlations (#171)
* Add more exhaustive tests of correlations. Fix a bug where edge weight rewrites in the search graph weren't always undone. * handle boundary edge component properly (fixing issue #175) and refactor * Add argument to determine if components are included as edges. Fix choice to include components. * remove support for cp38-macosx_arm64 * Add test to check for correlated matching boundary component handling (#176) Co-authored-by: oscarhiggott <29460323+oscarhiggott@users.noreply.github.com> --------- Co-authored-by: Noah Shutty <noajshu@users.noreply.github.com>
1 parent 3d5919c commit 6c33cdb

File tree

5 files changed

+74
-54
lines changed

5 files changed

+74
-54
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
{os: macos-latest, dist: cp312-macosx_x86_64},
5252
{os: macos-latest, dist: cp313-macosx_x86_64},
5353
# macosx arm64
54-
{os: macos-latest, dist: cp38-macosx_arm64},
54+
# {os: macos-latest, dist: cp38-macosx_arm64},
5555
{os: macos-latest, dist: cp39-macosx_arm64},
5656
{os: macos-latest, dist: cp310-macosx_arm64},
5757
{os: macos-latest, dist: cp311-macosx_arm64},

src/pymatching/sparse_blossom/driver/user_graph.cc

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@
1717
#include "pymatching/rand/rand_gen.h"
1818
#include "pymatching/sparse_blossom/driver/implied_weights.h"
1919

20+
namespace {
21+
22+
double bernoulli_xor(double p1, double p2) {
23+
return p1 * (1 - p2) + p2 * (1 - p1);
24+
}
25+
26+
} // namespace
27+
28+
29+
double pm::to_weight_for_correlations(double probability) {
30+
return std::log((1 - probability) / probability);
31+
}
32+
2033
double pm::merge_weights(double a, double b) {
2134
auto sgn = std::copysign(1, a) * std::copysign(1, b);
2235
auto signed_min = sgn * std::min(std::abs(a), std::abs(b));
@@ -342,6 +355,15 @@ void pm::UserGraph::handle_dem_instruction(
342355
}
343356
}
344357

358+
void pm::UserGraph::handle_dem_instruction_include_correlations(
359+
double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables) {
360+
if (detectors.size() == 2) {
361+
add_or_merge_edge(detectors[0], detectors[1], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
362+
} else if (detectors.size() == 1) {
363+
add_or_merge_boundary_edge(detectors[0], observables, pm::to_weight_for_correlations(p), p, INDEPENDENT);
364+
}
365+
}
366+
345367
void pm::UserGraph::get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes) {
346368
auto& mwpm = get_mwpm_with_search_graph();
347369
bool src_is_boundary = is_boundary_node(src);
@@ -444,18 +466,6 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin
444466
}
445467
}
446468

447-
namespace {
448-
449-
double bernoulli_xor(double p1, double p2) {
450-
return p1 * (1 - p2) + p2 * (1 - p1);
451-
}
452-
453-
double to_weight(double probability) {
454-
return std::log((1 - probability) / probability);
455-
}
456-
457-
} // namespace
458-
459469
void pm::add_decomposed_error_to_joint_probabilities(
460470
DecomposedDemError& error,
461471
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
@@ -490,7 +500,7 @@ pm::UserGraph pm::detector_error_model_to_user_graph(
490500
pm::iter_dem_instructions_include_correlations(
491501
detector_error_model,
492502
[&](double p, const std::vector<size_t>& detectors, std::vector<size_t>& observables) {
493-
user_graph.handle_dem_instruction(p, detectors, observables);
503+
user_graph.handle_dem_instruction_include_correlations(p, detectors, observables);
494504
},
495505
joint_probabilites);
496506

@@ -526,7 +536,7 @@ void pm::UserGraph::populate_implied_edge_weights(
526536
// minimum of 0.5 as an implied probability for an edge to be reweighted.
527537
double implied_probability_for_other_edge =
528538
std::min(0.5, affected_edge_and_probability.second / marginal_probability);
529-
double w = to_weight(implied_probability_for_other_edge);
539+
double w = pm::to_weight_for_correlations(implied_probability_for_other_edge);
530540
ImpliedWeightUnconverted implied{affected_edge.first, affected_edge.second, w};
531541
edge.implied_weights_for_other_edges.push_back(implied);
532542
}

src/pymatching/sparse_blossom/driver/user_graph.h

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ class UserGraph {
120120
Mwpm& get_mwpm();
121121
Mwpm& get_mwpm_with_search_graph();
122122
void handle_dem_instruction(double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
123+
void handle_dem_instruction_include_correlations(
124+
double p, const std::vector<size_t>& detectors, const std::vector<size_t>& observables);
123125
void get_nodes_on_shortest_path_from_source(size_t src, size_t dst, std::vector<size_t>& out_nodes);
124126
void populate_implied_edge_weights(
125127
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites);
@@ -131,6 +133,8 @@ class UserGraph {
131133
bool _all_edges_have_error_probabilities;
132134
};
133135

136+
double to_weight_for_correlations(double probability);
137+
134138
template <typename EdgeCallable, typename BoundaryEdgeCallable>
135139
inline double UserGraph::iter_discretized_edges(
136140
pm::weight_int num_distinct_weights,
@@ -263,7 +267,8 @@ template <typename Handler>
263267
void iter_dem_instructions_include_correlations(
264268
const stim::DetectorErrorModel& detector_error_model,
265269
const Handler& handle_dem_error,
266-
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
270+
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites,
271+
bool include_decomposed_error_components_in_edge_weights = true) {
267272
detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) {
268273
double p = instruction.arg_data[0];
269274
pm::DecomposedDemError decomposed_err;
@@ -311,35 +316,23 @@ void iter_dem_instructions_include_correlations(
311316
component->observable_indices.push_back(target.val());
312317
} else if (target.is_separator()) {
313318
instruction_contains_separator = true;
314-
// If the previous error in the decomposition had 3 or more detectors, we throw an exception.
315-
if (num_component_detectors > 2) {
316-
throw std::invalid_argument(
317-
"Encountered a decomposed error instruction with a hyperedge component (3 or more detectors). "
318-
"This is not supported.");
319-
} else if (num_component_detectors == 0) {
319+
// We cannot have num_component_detectors > 2 at this point, or we would have already thrown an
320+
// exception
321+
if (num_component_detectors == 0) {
320322
throw std::invalid_argument(
321323
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
322324
"This is not supported.");
323-
} else if (num_component_detectors > 0) {
324-
// If the previous error in the decomposition had 1 or 2 detectors, we handle it
325-
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
326-
decomposed_err.components.push_back({});
327-
component = &decomposed_err.components.back();
328-
component->node1 = SIZE_MAX;
329-
component->node2 = SIZE_MAX;
330-
num_component_detectors = 0;
331325
}
326+
// The previous error in the decomposition must have 1 or 2 detectors
327+
decomposed_err.components.push_back({});
328+
component = &decomposed_err.components.back();
329+
component->node1 = SIZE_MAX;
330+
component->node2 = SIZE_MAX;
331+
num_component_detectors = 0;
332332
}
333333
}
334334

335-
if (num_component_detectors > 2) {
336-
// Undecomposed hyperedges are not supported
337-
throw std::invalid_argument(
338-
"Encountered an undecomposed error instruction with 3 or mode detectors. "
339-
"This is not supported when using `enable_correlations=True`. "
340-
"Did you forget to set `decompose_errors=True` when "
341-
"converting the stim circuit to a detector error model?");
342-
} else if (num_component_detectors == 0) {
335+
if (num_component_detectors == 0) {
343336
if (instruction_contains_separator) {
344337
throw std::invalid_argument(
345338
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
@@ -348,12 +341,17 @@ void iter_dem_instructions_include_correlations(
348341
// Ignore errors that are undetectable, provided they are not a component of a decomposed error
349342
return;
350343
}
344+
}
351345

352-
} else if (num_component_detectors > 0) {
353-
if (component->node2 == SIZE_MAX) {
354-
handle_dem_error(p, {component->node1}, component->observable_indices);
355-
} else {
356-
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
346+
// If include_decomposed_error_components_in_edge_weights is False, then only add the edge into
347+
// the graph if it is not a component in a decomposed error with more than one component
348+
if (include_decomposed_error_components_in_edge_weights || decomposed_err.components.size() == 1) {
349+
for (pm::UserEdge& component : decomposed_err.components) {
350+
if (component.node2 == SIZE_MAX) {
351+
handle_dem_error(p, {component.node1}, component.observable_indices);
352+
} else {
353+
handle_dem_error(p, {component.node1, component.node2}, component.observable_indices);
354+
}
357355
}
358356
}
359357

src/pymatching/sparse_blossom/driver/user_graph.test.cc

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) {
384384
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
385385

386386
ASSERT_EQ(handler.handled_errors.size(), 4);
387-
387+
388388
std::vector<HandledError> expected = {
389389
{0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}};
390390
EXPECT_EQ(handler.handled_errors, expected);
@@ -485,14 +485,6 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
485485

486486
graph.populate_implied_edge_weights(joint_probabilities);
487487

488-
auto to_weight = [](double p) {
489-
if (p == 1.0)
490-
return -std::numeric_limits<double>::infinity();
491-
if (p == 0.0)
492-
return std::numeric_limits<double>::infinity();
493-
return std::log((1 - p) / p);
494-
};
495-
496488
auto it_01 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
497489
return edge.node1 == 0 && edge.node2 == 1;
498490
});
@@ -503,7 +495,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
503495
ASSERT_EQ(implied_01.node2, 3);
504496

505497
double p_01 = 0.1 / 0.26;
506-
double w_01 = to_weight(p_01);
498+
double w_01 = pm::to_weight_for_correlations(p_01);
507499
ASSERT_EQ(implied_01.implied_weight, w_01);
508500

509501
auto it_23 = std::find_if(graph.edges.begin(), graph.edges.end(), [](const pm::UserEdge& edge) {
@@ -515,7 +507,7 @@ TEST(UserGraph, PopulateImpliedEdgeWeights) {
515507
const auto& implied_23 = it_23->implied_weights_for_other_edges[0];
516508
ASSERT_EQ(implied_23.node1, 0);
517509
ASSERT_EQ(implied_23.node2, 1);
518-
ASSERT_EQ(implied_23.implied_weight, 0);
510+
ASSERT_NEAR(implied_23.implied_weight, 0.0, 0.00001);
519511
}
520512

521513
TEST(UserGraph, ConvertImpliedWeights) {

tests/matching/decode_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,26 @@ def test_decode_to_edges_with_correlations():
400400
assert np.array_equal(edges, expected_edges)
401401

402402

403+
def test_correlated_matching_handles_single_detector_components():
404+
stim = pytest.importorskip("stim")
405+
p = 0.1
406+
circuit = stim.Circuit.generated(
407+
code_task="surface_code:rotated_memory_x",
408+
distance=5,
409+
rounds=5,
410+
before_round_data_depolarization=p,
411+
)
412+
circ_str = str(circuit).replace(
413+
f"DEPOLARIZE1({p})", f"PAULI_CHANNEL_1(0, {p}, 0)"
414+
)
415+
noisy_circuit = stim.Circuit(circ_str)
416+
dem = noisy_circuit.detector_error_model(
417+
decompose_errors=True, approximate_disjoint_errors=True
418+
)
419+
m = Matching.from_detector_error_model(dem, enable_correlations=True)
420+
assert m.num_detectors > 0
421+
422+
403423
def test_load_from_circuit_with_correlations():
404424
stim = pytest.importorskip("stim")
405425
circuit = stim.Circuit.generated(

0 commit comments

Comments
 (0)