Skip to content

Commit 0dce487

Browse files
committed
axtract.py completely changed
1 parent 7c3f7f2 commit 0dce487

File tree

1 file changed

+145
-42
lines changed

1 file changed

+145
-42
lines changed

src/axiomatic/axtract.py

Lines changed: 145 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import ipywidgets as widgets # type: ignore
2-
from IPython.display import display # type: ignore
2+
from IPython.display import display, Math # type: ignore
33
from dataclasses import dataclass, field
4+
import hypernetx as hnx
5+
import matplotlib.pyplot as plt
6+
import re
7+
from ax_core.utils import printing
48

59
OPTION_LIST = {
610
"Select a template": [],
@@ -18,7 +22,7 @@
1822
"Pixel size (multispectral)",
1923
"Swath width",
2024
],
21-
"PAYLOAD": [
25+
"PAYLOAD": [
2226
"Resolution (panchromatic)",
2327
"Ground sampling distance (panchromatic)",
2428
"Resolution (multispectral)",
@@ -117,15 +121,14 @@ def requirements_from_table(results, variable_dict):
117121
name = key
118122
numerical_value = value["Value"]
119123
unit = value["Units"]
120-
tolerance = value["Tolerance"]
121124

122125
requirements.append(
123126
Requirement(
124127
requirement_name=name,
125128
latex_symbol=latex_symbol,
126129
value=numerical_value,
127130
units=unit,
128-
tolerance=tolerance,
131+
tolerance=0.0,
129132
)
130133
)
131134

@@ -178,9 +181,15 @@ def display_table(change):
178181

179182
if selected_option in preset_options_dict:
180183
rows = preset_options_dict[selected_option]
181-
max_name_length = max(len(name) for name in rows)
182-
# Update the name_label_width based on the longest row name
183-
name_label_width[0] = f"{max_name_length + 2}ch"
184+
185+
if selected_option != "Select a template":
186+
max_name_length = max(len(name) for name in rows)
187+
# Update the name_label_width based on the longest row name
188+
name_label_width[0] = f"{max_name_length + 2}ch"
189+
else:
190+
max_name_length = 40
191+
# Update the name_label_width based on the longest row name
192+
name_label_width[0] = f"{max_name_length + 2}ch"
184193

185194
# Add Headers
186195
header_labels = [
@@ -194,16 +203,6 @@ def display_table(change):
194203
layout=widgets.Layout(width="150px"),
195204
style={'font_weight': 'bold'}
196205
),
197-
widgets.Label(
198-
value="Tolerance",
199-
layout=widgets.Layout(width="150px"),
200-
style={'font_weight': 'bold'}
201-
),
202-
widgets.Label(
203-
value="Accuracy",
204-
layout=widgets.Layout(width="150px"),
205-
style={'font_weight': 'bold'}
206-
),
207206
widgets.Label(
208207
value="Units",
209208
layout=widgets.Layout(width="150px"),
@@ -216,7 +215,6 @@ def display_table(change):
216215
header.layout = widgets.Layout(
217216
border='1px solid black',
218217
padding='5px',
219-
background_color='#f0f0f0'
220218
)
221219

222220
# Add the header to the rows_output VBox
@@ -244,28 +242,19 @@ def display_table(change):
244242

245243
# Create input widgets
246244
value_text = widgets.FloatText(
247-
placeholder="Value",
248245
value=default_value,
249246
layout=widgets.Layout(width="150px"),
250247
)
251-
tolerance_text = widgets.FloatText(
252-
placeholder="Tolerance", layout=widgets.Layout(width="150px")
253-
)
254-
accuracy_text = widgets.FloatText(
255-
placeholder="Accuracy", layout=widgets.Layout(width="150px")
256-
)
257248
units_text = widgets.Text(
258-
placeholder="Units", layout=widgets.Layout(width="150px"),
259-
value = default_unit
249+
layout=widgets.Layout(width="150px"),
250+
value=default_unit
260251
)
261252

262253
# Combine widgets into a horizontal box
263254
row = widgets.HBox(
264255
[
265256
name_label,
266257
value_text,
267-
tolerance_text,
268-
accuracy_text,
269258
units_text,
270259
]
271260
)
@@ -291,16 +280,12 @@ def submit_values(_):
291280
if key.startswith("req_"):
292281
updated_values[variable] = {
293282
"Value": widget.children[1].value,
294-
"Tolerance": widget.children[2].value,
295-
"Accuracy": widget.children[3].value,
296-
"Units": widget.children[4].value,
283+
"Units": widget.children[2].value,
297284
}
298285
else:
299286
updated_values[key] = {
300287
"Value": widget.children[1].value,
301-
"Tolerance": widget.children[2].value,
302-
"Accuracy": widget.children[3].value,
303-
"Units": widget.children[4].value,
288+
"Units": widget.children[2].value,
304289
}
305290

306291
result["values"] = updated_values
@@ -327,18 +312,13 @@ def add_req(_):
327312
placeholder="Value",
328313
layout=widgets.Layout(width="150px"),
329314
)
330-
tolerance_text = widgets.FloatText(
331-
placeholder="Tolerance", layout=widgets.Layout(width="150px")
332-
)
333-
accuracy_text = widgets.FloatText(
334-
placeholder="Accuracy", layout=widgets.Layout(width="150px")
335-
)
315+
336316
units_text = widgets.Text(
337317
placeholder="Units", layout=widgets.Layout(width="150px")
338318
)
339319

340320
new_row = widgets.HBox(
341-
[variable_dropdown, value_text, tolerance_text, accuracy_text, units_text]
321+
[variable_dropdown, value_text, units_text]
342322
)
343323

344324
rows_output.children += (new_row,)
@@ -354,3 +334,126 @@ def add_req(_):
354334
display(buttons_box)
355335

356336
return result
337+
338+
339+
def display_formatted_answers(equations_dict):
340+
"""
341+
Display LaTeX formatted equations and numerical results from a nested
342+
dictionary structure in Jupyter Notebook.
343+
344+
Parameters:
345+
equations_dict (dict): The dictionary containing the equations.
346+
"""
347+
results = equations_dict.get('results', {})
348+
print("We identified the following equations that are relevant to your requirements:")
349+
350+
for key, value in results.items():
351+
latex_equation = value.get('latex_equation')
352+
lhs = value.get('lhs')
353+
rhs = value.get('rhs')
354+
match = value.get('match')
355+
if latex_equation:
356+
display(Math(latex_equation))
357+
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
358+
if match:
359+
print("Provided requirements fulfill this mathematical relation")
360+
else:
361+
print(f"No LaTeX equation found for {key}")
362+
363+
364+
def display_results(equations_dict):
365+
366+
results = equations_dict.get('results', {})
367+
not_match_counter = 0
368+
369+
for key, value in results.items():
370+
match = value.get('match')
371+
latex_equation = value.get('latex_equation')
372+
lhs = value.get('lhs')
373+
rhs = value.get('rhs')
374+
if not match:
375+
not_match_counter += 1
376+
print(
377+
printing.print_color.bold(
378+
printing.print_color.red(
379+
"Provided requirements DO NOT fulfill the following mathematical relation:"
380+
)
381+
)
382+
)
383+
display(Math(latex_equation))
384+
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
385+
if not_match_counter == 0:
386+
print(printing.print_color.green("Requirements you provided do not cause any conflicts"))
387+
388+
389+
def _get_latex_string_format(input_string):
390+
"""
391+
Properly formats LaTeX strings for matplotlib when text.usetex is False.
392+
No escaping needed since mathtext handles backslashes properly.
393+
"""
394+
return f"${input_string}$" # No backslash escaping required
395+
396+
397+
def _get_requirements_set(requirements):
398+
variable_set = set()
399+
for req in requirements:
400+
variable_set.add(req['latex_symbol'])
401+
402+
return variable_set
403+
404+
405+
def _find_vars_in_eq(equation, variable_set):
406+
patterns = [re.escape(var) for var in variable_set]
407+
combined_pattern = r'|'.join(patterns)
408+
matches = re.findall(combined_pattern, equation)
409+
return {fr"${match}$" for match in matches}
410+
411+
412+
def _add_used_vars_to_results(api_results, api_requirements):
413+
requirements = _get_requirements_set(api_requirements)
414+
415+
for key, value in api_results['results'].items():
416+
latex_equation = value.get('latex_equation')
417+
# print(latex_equation)
418+
if latex_equation:
419+
used_vars = _find_vars_in_eq(latex_equation, requirements)
420+
api_results['results'][key]['used_vars'] = used_vars
421+
422+
return api_results
423+
424+
425+
def get_eq_hypergraph(api_results, api_requirements):
426+
427+
# Disable external LaTeX rendering, using matplotlib's mathtext instead
428+
plt.rcParams['text.usetex'] = False
429+
plt.rcParams['mathtext.fontset'] = 'stix'
430+
plt.rcParams['font.family'] = 'serif'
431+
432+
api_results = _add_used_vars_to_results(api_results, api_requirements)
433+
434+
# Prepare the data for HyperNetX visualization
435+
hyperedges = {}
436+
for eq, details in api_results["results"].items():
437+
hyperedges[_get_latex_string_format(
438+
details["latex_equation"])] = details["used_vars"]
439+
440+
# Create the hypergraph using HyperNetX
441+
H = hnx.Hypergraph(hyperedges)
442+
443+
# Plot the hypergraph with enhanced clarity
444+
plt.figure(figsize=(16, 12))
445+
446+
# Draw the hypergraph with node and edge labels
447+
hnx.draw(
448+
H,
449+
with_edge_labels=True,
450+
edge_labels_on_edge=False,
451+
node_labels_kwargs={'fontsize': 14},
452+
edge_labels_kwargs={'fontsize': 14},
453+
layout_kwargs={'seed': 42, 'scale': 2.5}
454+
)
455+
456+
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
457+
plt.show()
458+
459+

0 commit comments

Comments
 (0)