Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ print(num_errors) # prints 8

To decode instead with correlated matching, set `enable_correlations=True` both when configuiing the `pymatching.Matching` object:
```python
matching_corr = pymatching.Matching.from_detector_error_model(dem, enable_correlations=True)
matching_corr = pymatching.Matching.from_detector_error_model(model, enable_correlations=True)
```

as well as when decoding:
Expand Down
32 changes: 28 additions & 4 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,14 +1296,23 @@ def from_detector_error_model(
return m

@staticmethod
def from_detector_error_model_file(dem_path: Union[str, Path]) -> 'pymatching.Matching':
def from_detector_error_model_file(
dem_path: Union[str, Path],
*,
enable_correlations: bool = False
) -> 'pymatching.Matching':
"""
Construct a `pymatching.Matching` by loading from a stim DetectorErrorModel file path.

Parameters
----------
dem_path : str
The path of the detector error model file
enable_correlations : bool, optional
If `enable_correlations==True`, the detector error model is converted into an internal
representation that allows correlated matching to be used. Note that you must set
`enable_correlations=True` here in order to use `enable_correlations=True` when decoding.
By default, False.

Returns
-------
Expand All @@ -1314,7 +1323,10 @@ def from_detector_error_model_file(dem_path: Union[str, Path]) -> 'pymatching.Ma
if isinstance(dem_path, Path):
dem_path = str(dem_path)
m = Matching()
m._matching_graph = _cpp_pm.detector_error_model_file_to_matching_graph(dem_path)
m._matching_graph = _cpp_pm.detector_error_model_file_to_matching_graph(
dem_path,
enable_correlations=enable_correlations
)
return m

@staticmethod
Expand Down Expand Up @@ -1371,7 +1383,11 @@ def from_stim_circuit(circuit: 'stim.Circuit', *, enable_correlations=False) ->
return m

@staticmethod
def from_stim_circuit_file(stim_circuit_path: Union[str, Path]) -> 'pymatching.Matching':
def from_stim_circuit_file(
stim_circuit_path: Union[str, Path],
*,
enable_correlations: bool = False
) -> 'pymatching.Matching':
"""
Construct a `pymatching.Matching` by loading from a stim circuit file path.

Expand All @@ -1386,11 +1402,19 @@ def from_stim_circuit_file(stim_circuit_path: Union[str, Path]) -> 'pymatching.M
A `pymatching.Matching` object representing the graphlike error mechanisms in the stim circuit
in the file `stim_circuit_path`, with any hyperedge error mechanisms decomposed into graphlike error
mechanisms. Parallel edges are merged using `merge_strategy="independent"`.
enable_correlations : bool, optional
If `enable_correlations==True`, the stim circuit's detector error model is converted into an internal
representation that allows correlated matching to be used. Note that you must set
`enable_correlations=True` here in order to use `enable_correlations=True` when decoding.
By default, False.
"""
if isinstance(stim_circuit_path, Path):
stim_circuit_path = str(stim_circuit_path)
m = Matching()
m._matching_graph = _cpp_pm.stim_circuit_file_to_matching_graph(stim_circuit_path)
m._matching_graph = _cpp_pm.stim_circuit_file_to_matching_graph(
stim_circuit_path,
enable_correlations=enable_correlations
)
return m

def _load_from_detector_error_model(self, model: 'stim.DetectorErrorModel', *, enable_correlations: bool = False) -> None:
Expand Down
41 changes: 29 additions & 12 deletions src/pymatching/sparse_blossom/driver/user_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ struct DecomposedDemError {
/// The probability of this error occurring.
double probability;
/// Effects of the error.
stim::FixedCapVector<UserEdge, 8> components;
std::vector<UserEdge> components;

bool operator==(const DecomposedDemError& other) const;
bool operator!=(const DecomposedDemError& other) const;
Expand Down Expand Up @@ -283,6 +283,7 @@ void iter_dem_instructions_include_correlations(
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
size_t num_component_detectors = 0;
bool instruction_contains_separator = false;
for (auto& target : instruction.target_data) {
// Decompose error
if (target.is_relative_detector_id()) {
Expand All @@ -309,30 +310,46 @@ void iter_dem_instructions_include_correlations(
} else if (target.is_observable_id()) {
component->observable_indices.push_back(target.val());
} else if (target.is_separator()) {
// If the previous error in the decomposition had 3 or more components, we ignore it.
if (component->node1 == SIZE_MAX) {
instruction_contains_separator = true;
// If the previous error in the decomposition had 3 or more detectors, we throw an exception.
if (num_component_detectors > 2) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with a hyperedge component (3 or more detectors). "
"This is not supported.");
} else if (p > 0) {
} else if (num_component_detectors == 0) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
"This is not supported.");
} else if (num_component_detectors > 0) {
// If the previous error in the decomposition had 1 or 2 detectors, we handle it
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
decomposed_err.components.push_back({});
component = &decomposed_err.components.back();
component->node1 = SIZE_MAX;
component->node2 = SIZE_MAX;
num_component_detectors = 0;
}
}
// If the final error in the decomposition had 3 or more components, we ignore it.
if (component->node1 == SIZE_MAX) {

if (num_component_detectors > 2) {
// Undecomposed hyperedges are not supported
throw std::invalid_argument(
"Encountered an undecomposed error instruction with 3 or mode detectors. "
"This is not supported when using `enable_correlations=True`. "
"Did you forget to set `decompose_errors=True` when "
"converting the stim circuit to a detector error model?");
} else if (p > 0) {
} else if (num_component_detectors == 0) {
if (instruction_contains_separator) {
throw std::invalid_argument(
"Encountered a decomposed error instruction with an undetectable component (0 detectors). "
"This is not supported.");
} else {
// Ignore errors that are undetectable, provided they are not a component of a decomposed error
return;
}

} else if (num_component_detectors > 0) {
if (component->node2 == SIZE_MAX) {
handle_dem_error(p, {component->node1}, component->observable_indices);
} else {
Expand Down
31 changes: 28 additions & 3 deletions src/pymatching/sparse_blossom/driver/user_graph.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ TEST(IterDemInstructionsTest, ThreeDetectorErrorThrowsInvalidArgument) {
stim::DetectorErrorModel dem("error(0.1) D0 D1 D2");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
ASSERT_THROW(pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);
ASSERT_THROW(
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);
}

// Test a decomposed error instruction. The handler should be called for each component.
Expand Down Expand Up @@ -353,13 +354,30 @@ TEST(IterDemInstructionsTest, DecomposedErrorWithHyperedgeThrows) {
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);
}

// Test that a decomposed error with an undetectable component throws an exception.
TEST(IterDemInstructionsTest, DecomposedErrorWithUndetectableErrorThrows) {
stim::DetectorErrorModel dem("error(0.15) L0 ^ D2 D4 ^ D5 D6 L2");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;

// Assert that the function throws std::invalid_argument when processing the DEM.
ASSERT_THROW(
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);

stim::DetectorErrorModel dem2("error(0.15) D2 D4 ^ D5 D6 L2 ^ L1");
// Assert that the function throws std::invalid_argument when processing the DEM.
ASSERT_THROW(
pm::iter_dem_instructions_include_correlations(dem2, handler, joint_probabilities), std::invalid_argument);
}

// Test a complex DEM with multiple instruction types and edge cases combined.
TEST(IterDemInstructionsTest, CombinedComplexDem) {
stim::DetectorErrorModel dem(R"DEM(
error(0.1) D0 # Instruction 1: Simple
error(0.3) L0 # Instruction 2: Undetectable error, ignored
error(0.2) D1 D2 L0 # Instruction 2: Two detectors, one observable
error(0.0) D7 # Instruction 4: Zero probability, ignored
error(0.4) D8 ^ D9 L1 # Instruction 5: Decomposed
error(0.0) D7 # Instruction 3: Zero probability, ignored
error(0.4) D8 ^ D9 L1 # Instruction 4: Decomposed
)DEM");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
Expand Down Expand Up @@ -395,6 +413,13 @@ double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

TEST(IterDemInstructionsTest, MoreThanEightComponents) {
stim::DetectorErrorModel dem("error(0.1) D0 ^ D1 ^ D2 ^ D3 ^ D4 ^ D5 ^ D6 ^ D7 ^ D8");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
}

// Tests that multiple error instructions on the same edge correctly combine their probabilities.
TEST(IterDemInstructionsTest, MultipleErrorsOnSameEdgeCombine) {
stim::DetectorErrorModel dem("error(0.1) D0 D1\n error(0.2) D0 D1");
Expand Down
36 changes: 26 additions & 10 deletions tests/matching/decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,22 +387,32 @@ def test_load_from_circuit_with_correlations():
predictions, weights = m.decode_batch(shots=shots, return_weights=True, enable_correlations=True)


def test_use_correlations_with_uncorrelated_dem_load_raises_value_error():
def test_use_correlations_with_uncorrelated_dem_load_raises_value_error(tmp_path):
stim = pytest.importorskip("stim")
d = 3
p = 0.001
circuit = stim.Circuit.generated(
code_task="surface_code:rotated_memory_x",
distance=3,
rounds=3,
after_clifford_depolarization=0.001
distance=d,
rounds=d,
after_clifford_depolarization=p
)
dem = circuit.detector_error_model(decompose_errors=True)
shots = circuit.compile_detector_sampler().sample(shots=10)
matching_1 = pymatching.Matching(circuit, enable_correlations=False)
matching_2 = pymatching.Matching.from_stim_circuit(circuit=circuit, enable_correlations=False)
matching_3 = pymatching.Matching.from_detector_error_model(
model=circuit.detector_error_model(decompose_errors=True),
model=dem,
enable_correlations=False
)
for m in (matching_1, matching_2, matching_3):
fn = f"surface_code_x_d{d}_r{d}_p{p}"
stim_file = tmp_path / f"{fn}.stim"
circuit.to_file(stim_file)
matching_4 = pymatching.Matching.from_stim_circuit_file(stim_file, enable_correlations=False)
dem_file = tmp_path / f"{fn}.dem"
dem.to_file(dem_file)
matching_5 = pymatching.Matching.from_detector_error_model_file(dem_file, enable_correlations=False)
for m in (matching_1, matching_2, matching_3, matching_4, matching_5):
with pytest.raises(ValueError):
m.decode_batch(shots=shots, return_weights=True, enable_correlations=True)
with pytest.raises(ValueError):
Expand All @@ -413,16 +423,22 @@ def test_use_correlations_with_uncorrelated_dem_load_raises_value_error():
m.decode(shots[0], enable_correlations=True)


def test_use_correlations_without_decompose_errors_raises_value_error():
def test_use_correlations_without_decompose_errors_raises_value_error(tmp_path):
stim = pytest.importorskip("stim")
d = 3
p = 0.001
circuit = stim.Circuit.generated(
code_task="surface_code:rotated_memory_x",
distance=3,
rounds=3,
after_clifford_depolarization=0.001
distance=d,
rounds=d,
after_clifford_depolarization=p
)
dem = circuit.detector_error_model(decompose_errors=False)
dem_file = tmp_path / "surface_code.dem"
dem.to_file(dem_file)
with pytest.raises(ValueError):
pymatching.Matching.from_detector_error_model(dem, enable_correlations=True)
with pytest.raises(ValueError):
pymatching.Matching(dem, enable_correlations=True)
with pytest.raises(ValueError):
pymatching.Matching.from_detector_error_model_file(dem_file, enable_correlations=True)
Loading