diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index b3041fa6..f933a9ad 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -1,10 +1,11 @@ -import pytest -from kirin.passes import HintConst, inline from kirin.dialects import scf +from kirin.passes.inline import InlinePass from bloqade import squin from bloqade.analysis.measure_id import MeasurementIDAnalysis +from bloqade.stim.passes.flatten import Flatten from bloqade.analysis.measure_id.lattice import ( + NotMeasureId, MeasureIdBool, MeasureIdTuple, InvalidMeasureId, @@ -15,7 +16,16 @@ def results_at(kern, block_id, stmt_id): return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore -@pytest.mark.xfail +def results_of_variables(kernel, variable_names): + results = {} + for stmt in kernel.callable_region.stmts(): + for result in stmt.results: + if result.name in variable_names: + results[result.name] = result + + return results + + def test_add(): @squin.kernel def test(): @@ -28,6 +38,8 @@ def test(): ml2 = squin.broadcast.measure(ql2) return ml1 + ml2 + Flatten(test.dialects).fixpoint(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) measure_id_tuples = [ @@ -41,7 +53,6 @@ def test(): assert measure_id_tuples[-1] == expected_measure_id_tuple -@pytest.mark.xfail def test_measure_alias(): @squin.kernel @@ -52,28 +63,33 @@ def test(): return ml_alias + Flatten(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) - test.print(analysis=frame.entries) - # Collect MeasureIdTuples measure_id_tuples = [ value for value in frame.entries.values() if isinstance(value, MeasureIdTuple) ] - # construct expected MeasureIdTuple - expected_measure_id_tuple = MeasureIdTuple( + # construct expected MeasureIdTuples + measure_id_tuple_with_id_bools = MeasureIdTuple( data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)]) ) + measure_id_tuple_with_not_measures = MeasureIdTuple( + data=tuple([NotMeasureId() for _ in range(5)]) + ) - assert len(measure_id_tuples) == 2 + assert len(measure_id_tuples) == 3 + # New qubit.new semantics cause a MeasureIdTuple to be generated full of NotMeasureIds because + # qubit.new is actually an ilist.map that invokes single qubit allocation multiple times + # and puts them into an ilist. + assert measure_id_tuples[0] == measure_id_tuple_with_not_measures assert all( - measure_id_tuple == expected_measure_id_tuple - for measure_id_tuple in measure_id_tuples + measure_id_tuple == measure_id_tuple_with_id_bools + for measure_id_tuple in measure_id_tuples[1:] ) -@pytest.mark.xfail def test_measure_count_at_if_else(): @squin.kernel @@ -88,6 +104,7 @@ def test(): if ms[3]: squin.y(q[1]) + Flatten(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) assert all( @@ -96,32 +113,29 @@ def test(): ) -@pytest.mark.xfail def test_scf_cond_true(): @squin.kernel def test(): - q = squin.qalloc(1) + q = squin.qalloc(3) squin.x(q[2]) ms = None cond = True if cond: - ms = squin.broadcast.measure(q) + ms = squin.measure(q[1]) else: ms = squin.measure(q[0]) return ms - HintConst(dialects=test.dialects).unsafe_run(test) + InlinePass(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) - # MeasureIdTuple(data=MeasureIdBool(idx=1),) should occur twice: + # MeasureIdBool(idx=1) should occur twice: # First from the measurement in the true branch, then # the result of the scf.IfElse itself analysis_results = [ - val - for val in frame.entries.values() - if val == MeasureIdTuple(data=(MeasureIdBool(idx=1),)) + val for val in frame.entries.values() if val == MeasureIdBool(idx=1) ] assert len(analysis_results) == 2 @@ -136,16 +150,16 @@ def test(): ms = None cond = False if cond: - ms = squin.broadcast.measure(q) + ms = squin.measure(q[1]) else: - ms = squin.qubit.measure(q[0]) + ms = squin.measure(q[0]) return ms - inline.InlinePass(test.dialects).fixpoint(test) - - HintConst(dialects=test.dialects).unsafe_run(test) + # need to preserve the scf.IfElse but need things like qalloc to be inlined + InlinePass(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + test.print(analysis=frame.entries) # MeasureIdBool(idx=1) should occur twice: # First from the measurement in the false branch, then @@ -156,7 +170,37 @@ def test(): assert len(analysis_results) == 2 -@pytest.mark.xfail +def test_scf_cond_unknown(): + + @squin.kernel + def test(cond: bool): + q = squin.qalloc(5) + squin.x(q[2]) + + if cond: + ms = squin.broadcast.measure(q) + else: + ms = squin.measure(q[0]) + + return ms + + # We can use Flatten here because the variable condition for the scf.IfElse + # means it cannot be simplified. + Flatten(test.dialects).fixpoint(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + analysis_results = [ + val for val in frame.entries.values() if isinstance(val, MeasureIdTuple) + ] + # Both branches of the scf.IfElse should be properly traversed and contain the following + # analysis results. + expected_full_register_measurement = MeasureIdTuple( + data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)]) + ) + expected_else_measurement = MeasureIdTuple(data=(MeasureIdBool(idx=6),)) + assert expected_full_register_measurement in analysis_results + assert expected_else_measurement in analysis_results + + def test_slice(): @squin.kernel def test(): @@ -170,19 +214,23 @@ def test(): return ms_final + Flatten(test.dialects).fixpoint(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) - test.print(analysis=frame.entries) + results = results_of_variables(test, ("msi", "msi2", "ms_final")) - assert [frame.entries[result] for result in results_at(test, 0, 7)] == [ - MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7)))) - ] - assert [frame.entries[result] for result in results_at(test, 0, 9)] == [ - MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7)))) - ] - assert [frame.entries[result] for result in results_at(test, 0, 11)] == [ - MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5))) - ] + # This is an assertion against `msi` NOT the initial list of measurements + assert frame.get(results["msi"]) == MeasureIdTuple( + data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))) + ) + # msi2 + assert frame.get(results["msi2"]) == MeasureIdTuple( + data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))) + ) + # ms_final + assert frame.get(results["ms_final"]) == MeasureIdTuple( + data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)) + ) def test_getitem_no_hint():