Skip to content

Commit ac6c2f4

Browse files
committed
Print statements
1 parent c19ba00 commit ac6c2f4

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ipywidgets import interactive, IntSlider # type: ignore
44
from typing import List, Optional
55

6-
from . import Parameter
6+
from . import Parameter, StatementDictionary, StatementValidationDictionary
77

88

99
def plot_circuit(component):
@@ -201,3 +201,85 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
201201
]
202202
)
203203
plt.show()
204+
205+
206+
def print_statements(statements: StatementDictionary, validation: StatementValidationDictionary | None = None):
207+
"""
208+
Print a list of statements in nice readable format.
209+
"""
210+
validation = validation or StatementValidationDictionary(
211+
cost_functions=[None]*len(statements.cost_functions),
212+
parameter_constraints=[None]*len(statements.parameter_constraints),
213+
structure_constraints=[None]*len(statements.structure_constraints),
214+
unformalizable_statements=[None]*len(statements.unformalizable_statements)
215+
)
216+
217+
if len(validation.cost_functions) != len(statements.cost_functions):
218+
raise ValueError("Number of cost functions and validations do not match.")
219+
if len(validation.parameter_constraints) != len(statements.parameter_constraints):
220+
raise ValueError("Number of parameter constraints and validations do not match.")
221+
if len(validation.structure_constraints) != len(statements.structure_constraints):
222+
raise ValueError("Number of structure constraints and validations do not match.")
223+
if len(validation.unformalizable_statements) != len(statements.unformalizable_statements):
224+
raise ValueError("Number of unformalizable statements and validations do not match.")
225+
226+
print("-----------------------------------\n")
227+
for z3_stmt, z3_val in zip((statements.cost_functions or []) + (statements.parameter_constraints or []), (validation.cost_functions or []) + (validation.parameter_constraints or []), strict=True):
228+
print("Type:", z3_stmt.type)
229+
print("Statement:", z3_stmt.text)
230+
print("Formalization:", end=" ")
231+
if z3_stmt.formalization is None:
232+
print("UNFORMALIZED")
233+
else:
234+
code = z3_stmt.formalization.code
235+
if z3_stmt.formalization.mapping is not None:
236+
for var_name, computation in z3_stmt.formalization.mapping.items():
237+
if computation is not None:
238+
args_str = ", ".join(
239+
[
240+
f"{argname}="
241+
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
242+
for argname, argvalue in computation.arguments.items()
243+
]
244+
)
245+
code = code.replace(var_name, f"{computation.name}({args_str})")
246+
print(code)
247+
val = z3_stmt.validation or z3_val
248+
if val is not None:
249+
if z3_stmt.type == "COST_FUNCTION":
250+
print(f"Satisfiable: {val.satisfiable}")
251+
print(val.message)
252+
else:
253+
print(f"Satisfiable: {val.satisfiable}")
254+
print(f"Holds: {val.holds} ({val.message})")
255+
print("\n-----------------------------------\n")
256+
for struct_stmt, struct_val in zip(statements.structure_constraints or [], validation.structure_constraints or [], strict=True):
257+
print("Type:", struct_stmt.type)
258+
print("Statement:", struct_stmt.text)
259+
print("Formalization:", end=" ")
260+
if struct_stmt.formalization is None:
261+
print("UNFORMALIZED")
262+
else:
263+
func_constr = struct_stmt.formalization
264+
args_str = ", ".join(
265+
[
266+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
267+
for argname, argvalue in func_constr.arguments.items()
268+
]
269+
)
270+
func_str = f"{func_constr.function_name}({args_str}) == {func_constr.expected_result}"
271+
print(func_str)
272+
val = struct_stmt.validation or struct_val
273+
if val is not None:
274+
print(f"Satisfiable: {val.satisfiable}")
275+
print(f"Holds: {val.holds}")
276+
print("\n-----------------------------------\n")
277+
for unf_stmt, unf_val in zip(statements.unformalizable_statements or [], validation.unformalizable_statements or [], strict=True):
278+
print("Type:", unf_stmt.type)
279+
print("Statement:", unf_stmt.text)
280+
print("Formalization: UNFORMALIZABLE")
281+
val = unf_stmt.validation or unf_val
282+
if val is not None:
283+
print(f"Satisfiable: {val.satisfiable}")
284+
print(f"Holds: {val.holds}")
285+
print("\n-----------------------------------\n")

0 commit comments

Comments
 (0)