Skip to content

Commit 236fd82

Browse files
committed
Print statements
1 parent c19ba00 commit 236fd82

File tree

1 file changed

+94
-5
lines changed

1 file changed

+94
-5
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ipywidgets import interactive, IntSlider # type: ignore
44
from typing import List, Optional
55

6-
from . import Parameter
6+
from . import Parameter, StatementDictionary, StatementValidationDictionary
77

88

99
def plot_circuit(component):
@@ -80,8 +80,8 @@ def plot_single_spectrum(
8080
"""
8181
Plot a single spectrum with vertical and horizontal lines.
8282
"""
83-
hlines = hlines or []
84-
vlines = vlines or []
83+
hlines = hlines
84+
vlines = vlines
8585

8686
plt.clf()
8787
plt.figure(figsize=(10, 5))
@@ -142,8 +142,8 @@ def plot_interactive_spectra(
142142

143143
slider_index = slider_index or list(range(len(spectra[0])))
144144
spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectra))]
145-
vlines = vlines or []
146-
hlines = hlines or []
145+
vlines = vlines
146+
hlines = hlines
147147

148148
# Function to update the plot
149149
def plot_array(index=0):
@@ -201,3 +201,92 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
201201
]
202202
)
203203
plt.show()
204+
205+
206+
def print_statements(statements: StatementDictionary, validation: StatementValidationDictionary | None = None):
207+
"""
208+
Print a list of statements in nice readable format.
209+
"""
210+
statements = StatementDictionary(
211+
cost_functions=statements.cost_functions or [],
212+
parameter_constraints=statements.parameter_constraints or [],
213+
structure_constraints=statements.structure_constraints or [],
214+
unformalizable_statements=statements.unformalizable_statements or [],
215+
)
216+
217+
validation = StatementValidationDictionary(
218+
cost_functions=(validation.cost_functions if validation is not None else None) or [None]*len(statements.cost_functions),
219+
parameter_constraints=(validation.parameter_constraints if validation is not None else None) or [None]*len(statements.parameter_constraints),
220+
structure_constraints=(validation.structure_constraints if validation is not None else None) or [None]*len(statements.structure_constraints),
221+
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None) or [None]*len(statements.unformalizable_statements)
222+
)
223+
224+
if len(validation.cost_functions) != len(statements.cost_functions):
225+
raise ValueError("Number of cost functions and validations do not match.")
226+
if len(validation.parameter_constraints) != len(statements.parameter_constraints):
227+
raise ValueError("Number of parameter constraints and validations do not match.")
228+
if len(validation.structure_constraints) != len(statements.structure_constraints):
229+
raise ValueError("Number of structure constraints and validations do not match.")
230+
if len(validation.unformalizable_statements) != len(statements.unformalizable_statements):
231+
raise ValueError("Number of unformalizable statements and validations do not match.")
232+
233+
print("-----------------------------------\n")
234+
for z3_stmt, z3_val in zip((statements.cost_functions) + (statements.parameter_constraints), (validation.cost_functions) + (validation.parameter_constraints)):
235+
print("Type:", z3_stmt.type)
236+
print("Statement:", z3_stmt.text)
237+
print("Formalization:", end=" ")
238+
if z3_stmt.formalization is None:
239+
print("UNFORMALIZED")
240+
else:
241+
code = z3_stmt.formalization.code
242+
if z3_stmt.formalization.mapping is not None:
243+
for var_name, computation in z3_stmt.formalization.mapping.items():
244+
if computation is not None:
245+
args_str = ", ".join(
246+
[
247+
f"{argname}="
248+
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
249+
for argname, argvalue in computation.arguments.items()
250+
]
251+
)
252+
code = code.replace(var_name, f"{computation.name}({args_str})")
253+
print(code)
254+
val = z3_stmt.validation or z3_val
255+
if val is not None:
256+
if z3_stmt.type == "COST_FUNCTION":
257+
print(f"Satisfiable: {val.satisfiable}")
258+
print(val.message)
259+
else:
260+
print(f"Satisfiable: {val.satisfiable}")
261+
print(f"Holds: {val.holds} ({val.message})")
262+
print("\n-----------------------------------\n")
263+
for struct_stmt, struct_val in zip(statements.structure_constraints, validation.structure_constraints):
264+
print("Type:", struct_stmt.type)
265+
print("Statement:", struct_stmt.text)
266+
print("Formalization:", end=" ")
267+
if struct_stmt.formalization is None:
268+
print("UNFORMALIZED")
269+
else:
270+
func_constr = struct_stmt.formalization
271+
args_str = ", ".join(
272+
[
273+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
274+
for argname, argvalue in func_constr.arguments.items()
275+
]
276+
)
277+
func_str = f"{func_constr.function_name}({args_str}) == {func_constr.expected_result}"
278+
print(func_str)
279+
val = struct_stmt.validation or struct_val
280+
if val is not None:
281+
print(f"Satisfiable: {val.satisfiable}")
282+
print(f"Holds: {val.holds}")
283+
print("\n-----------------------------------\n")
284+
for unf_stmt, unf_val in zip(statements.unformalizable_statements, validation.unformalizable_statements):
285+
print("Type:", unf_stmt.type)
286+
print("Statement:", unf_stmt.text)
287+
print("Formalization: UNFORMALIZABLE")
288+
val = unf_stmt.validation or unf_val
289+
if val is not None:
290+
print(f"Satisfiable: {val.satisfiable}")
291+
print(f"Holds: {val.holds}")
292+
print("\n-----------------------------------\n")

0 commit comments

Comments
 (0)