diff --git a/crates/pecos-python/src/sparse_sim.rs b/crates/pecos-python/src/sparse_sim.rs index 213b60121..acdcc36e3 100644 --- a/crates/pecos-python/src/sparse_sim.rs +++ b/crates/pecos-python/src/sparse_sim.rs @@ -10,6 +10,8 @@ // or implied. See the License for the specific language governing permissions and limitations under // the License. +#![allow(clippy::useless_conversion)] + use pecos::prelude::*; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; diff --git a/python/quantum-pecos/src/pecos/classical_interpreters/phir_classical_interpreter.py b/python/quantum-pecos/src/pecos/classical_interpreters/phir_classical_interpreter.py index 955a69376..28e8f6c00 100644 --- a/python/quantum-pecos/src/pecos/classical_interpreters/phir_classical_interpreter.py +++ b/python/quantum-pecos/src/pecos/classical_interpreters/phir_classical_interpreter.py @@ -15,6 +15,7 @@ import warnings from typing import TYPE_CHECKING, Any +import numpy as np from phir.model import PHIRModel from pecos.classical_interpreters.classical_interpreter_abc import ClassicalInterpreter @@ -203,8 +204,22 @@ def get_cval(self, cvar): return self.cenv[cid] def get_bit(self, cvar, idx): - val = self.get_cval(cvar) & (1 << idx) - val >>= idx + cval = self.get_cval(cvar) + dtype = type(cval) + + # Check if idx is within valid range for the data type + bit_width = 8 * np.dtype(dtype).itemsize + if idx >= bit_width: + msg = f"Bit index {idx} out of range for {dtype} (max {bit_width - 1})" + raise ValueError( + msg, + ) + + # Use the same data type for the constant 1 + one = dtype(1) + mask = one << dtype(idx) + + val = (cval & mask) >> dtype(idx) return val def eval_expr(self, expr: int | str | list | pt.opt.COp) -> int | None: diff --git a/python/tests/pecos/unit/test_phir_classical_interpreter.py b/python/tests/pecos/unit/test_phir_classical_interpreter.py new file mode 100644 index 000000000..0b8ebc168 --- /dev/null +++ b/python/tests/pecos/unit/test_phir_classical_interpreter.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest +from pecos.classical_interpreters.phir_classical_interpreter import ( + PHIRClassicalInterpreter, +) + +# Note: This test assumes the get_bit method has been updated to include bounds checking. +# If you're implementing the int() conversion approach instead, this test should be removed. + + +@pytest.fixture +def interpreter(): + """Create and initialize a PHIRClassicalInterpreter with essential test data.""" + interpreter = PHIRClassicalInterpreter() + + # Set up test variables + interpreter.csym2id = { + "u8_var": 0, + "u64_var": 1, + } + + # Test patterns: alternating bits for u8, highest bit set for u64 + interpreter.cenv = [ + np.uint8(0b10101010), # u8_var with alternating bits + np.uint64(0x8000000000000000), # u64_var with only bit 63 set + ] + + interpreter.cid2dtype = [ + np.uint8, + np.uint64, + ] + + return interpreter + + +def test_get_bit_basic_functionality(interpreter): + """Test basic bit retrieval functionality.""" + # Test alternating 0s and 1s in the 8-bit variable + assert interpreter.get_bit("u8_var", 0) == 0 + assert interpreter.get_bit("u8_var", 1) == 1 + assert interpreter.get_bit("u8_var", 7) == 1 + + +def test_get_bit_highest_bit(interpreter): + """Test accessing the highest bit of a 64-bit value, which is most likely to cause issues.""" + # This is the critical test for the potential overflow issue + assert interpreter.get_bit("u64_var", 63) == 1 + + # Verify lower bits are 0 + assert interpreter.get_bit("u64_var", 0) == 0 + assert interpreter.get_bit("u64_var", 62) == 0 + + +def test_get_bit_out_of_bounds(interpreter): + """Test that attempting to access bits beyond the data type width raises an error.""" + # Test with specific error message patterns matching the implementation + with pytest.raises( + ValueError, + match=r"Bit index 8 out of range for.*uint8.* \(max 7\)", + ): + interpreter.get_bit("u8_var", 8) # u8 has bits 0-7 only + + with pytest.raises( + ValueError, + match=r"Bit index 64 out of range for.*uint64.* \(max 63\)", + ): + interpreter.get_bit("u64_var", 64) # u64 has bits 0-63 only + + # Test with an extremely large index + with pytest.raises( + ValueError, + match=r"Bit index 1000 out of range for.*uint64.* \(max 63\)", + ): + interpreter.get_bit("u64_var", 1000)