diff --git a/src/bloqade/stim/emit/stim_str.py b/src/bloqade/stim/emit/stim_str.py index 640d97a1f..354f14f2c 100644 --- a/src/bloqade/stim/emit/stim_str.py +++ b/src/bloqade/stim/emit/stim_str.py @@ -27,6 +27,12 @@ def initialize(self): self.file.seek(0) return self + def run_method( + self, method: ir.Method, args: tuple[str, ...] + ) -> tuple[EmitStrFrame, str]: + self._current_method = method + return super().run_method(method, args) + def eval_stmt_fallback( self, frame: EmitStrFrame, stmt: ir.Statement ) -> tuple[str, ...]: @@ -34,6 +40,7 @@ def eval_stmt_fallback( def emit_block(self, frame: EmitStrFrame, block: ir.Block) -> str | None: for stmt in block.stmts: + frame.current_stmt = stmt result = self.eval_stmt(frame, stmt) if isinstance(result, tuple): frame.set_values(stmt.results, result) @@ -43,6 +50,18 @@ def get_output(self) -> str: self.file.seek(0) return self.file.read() + def writeln(self, frame: EmitStrFrame, *args): + if self.debug: + self.newline(frame) + source = frame.current_stmt.source + if source is not None: + self.file.write( + f"# v-- {source.file}:{source.lineno + self._current_method.lineno_begin -1}" + ) + else: + self.file.write("# v-- unknown source") + super().writeln(frame, *args) + @func.dialect.register(key="emit.stim") class FuncEmit(interp.MethodTable): @@ -50,5 +69,4 @@ class FuncEmit(interp.MethodTable): @interp.impl(func.Function) def emit_func(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: func.Function): _ = emit.run_ssacfg_region(frame, stmt.body, ()) - # emit.output = "\n".join(frame.body) return () diff --git a/test/stim/emit/__init__.py b/test/stim/emit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/stim/emit/test_stim_str.py b/test/stim/emit/test_stim_str.py new file mode 100644 index 000000000..67043a114 --- /dev/null +++ b/test/stim/emit/test_stim_str.py @@ -0,0 +1,22 @@ +import pytest + +from bloqade import stim +from bloqade.stim.emit import EmitStimMain + + +@pytest.mark.parametrize("debug", [True, False]) +def test_debug_emit_with_source_info(debug: bool): + @stim.main + def test(): + stim.cx((0, 1), (2, 3)) + + emit = EmitStimMain(debug=debug) + emit.initialize() + emit.run(mt=test, args=()) + output = emit.get_output() + + if debug: + assert "# v--" in output + assert ".py:" in output + else: + assert "# v--" not in output