11import ipywidgets as widgets # type: ignore
2- from IPython .display import display # type: ignore
2+ from IPython .display import display , Math # type: ignore
33from dataclasses import dataclass , field
4+ import hypernetx as hnx
5+ import matplotlib .pyplot as plt
6+ import re
7+ from ax_core .utils import printing
48
59OPTION_LIST = {
610 "Select a template" : [],
1822 "Pixel size (multispectral)" ,
1923 "Swath width" ,
2024 ],
21- "PAYLOAD" : [
25+ "PAYLOAD" : [
2226 "Resolution (panchromatic)" ,
2327 "Ground sampling distance (panchromatic)" ,
2428 "Resolution (multispectral)" ,
@@ -117,15 +121,14 @@ def requirements_from_table(results, variable_dict):
117121 name = key
118122 numerical_value = value ["Value" ]
119123 unit = value ["Units" ]
120- tolerance = value ["Tolerance" ]
121124
122125 requirements .append (
123126 Requirement (
124127 requirement_name = name ,
125128 latex_symbol = latex_symbol ,
126129 value = numerical_value ,
127130 units = unit ,
128- tolerance = tolerance ,
131+ tolerance = 0.0 ,
129132 )
130133 )
131134
@@ -178,9 +181,15 @@ def display_table(change):
178181
179182 if selected_option in preset_options_dict :
180183 rows = preset_options_dict [selected_option ]
181- max_name_length = max (len (name ) for name in rows )
182- # Update the name_label_width based on the longest row name
183- name_label_width [0 ] = f"{ max_name_length + 2 } ch"
184+
185+ if selected_option != "Select a template" :
186+ max_name_length = max (len (name ) for name in rows )
187+ # Update the name_label_width based on the longest row name
188+ name_label_width [0 ] = f"{ max_name_length + 2 } ch"
189+ else :
190+ max_name_length = 40
191+ # Update the name_label_width based on the longest row name
192+ name_label_width [0 ] = f"{ max_name_length + 2 } ch"
184193
185194 # Add Headers
186195 header_labels = [
@@ -194,16 +203,6 @@ def display_table(change):
194203 layout = widgets .Layout (width = "150px" ),
195204 style = {'font_weight' : 'bold' }
196205 ),
197- widgets .Label (
198- value = "Tolerance" ,
199- layout = widgets .Layout (width = "150px" ),
200- style = {'font_weight' : 'bold' }
201- ),
202- widgets .Label (
203- value = "Accuracy" ,
204- layout = widgets .Layout (width = "150px" ),
205- style = {'font_weight' : 'bold' }
206- ),
207206 widgets .Label (
208207 value = "Units" ,
209208 layout = widgets .Layout (width = "150px" ),
@@ -216,7 +215,6 @@ def display_table(change):
216215 header .layout = widgets .Layout (
217216 border = '1px solid black' ,
218217 padding = '5px' ,
219- background_color = '#f0f0f0'
220218 )
221219
222220 # Add the header to the rows_output VBox
@@ -244,28 +242,19 @@ def display_table(change):
244242
245243 # Create input widgets
246244 value_text = widgets .FloatText (
247- placeholder = "Value" ,
248245 value = default_value ,
249246 layout = widgets .Layout (width = "150px" ),
250247 )
251- tolerance_text = widgets .FloatText (
252- placeholder = "Tolerance" , layout = widgets .Layout (width = "150px" )
253- )
254- accuracy_text = widgets .FloatText (
255- placeholder = "Accuracy" , layout = widgets .Layout (width = "150px" )
256- )
257248 units_text = widgets .Text (
258- placeholder = "Units" , layout = widgets .Layout (width = "150px" ),
259- value = default_unit
249+ layout = widgets .Layout (width = "150px" ),
250+ value = default_unit
260251 )
261252
262253 # Combine widgets into a horizontal box
263254 row = widgets .HBox (
264255 [
265256 name_label ,
266257 value_text ,
267- tolerance_text ,
268- accuracy_text ,
269258 units_text ,
270259 ]
271260 )
@@ -291,16 +280,12 @@ def submit_values(_):
291280 if key .startswith ("req_" ):
292281 updated_values [variable ] = {
293282 "Value" : widget .children [1 ].value ,
294- "Tolerance" : widget .children [2 ].value ,
295- "Accuracy" : widget .children [3 ].value ,
296- "Units" : widget .children [4 ].value ,
283+ "Units" : widget .children [2 ].value ,
297284 }
298285 else :
299286 updated_values [key ] = {
300287 "Value" : widget .children [1 ].value ,
301- "Tolerance" : widget .children [2 ].value ,
302- "Accuracy" : widget .children [3 ].value ,
303- "Units" : widget .children [4 ].value ,
288+ "Units" : widget .children [2 ].value ,
304289 }
305290
306291 result ["values" ] = updated_values
@@ -327,18 +312,13 @@ def add_req(_):
327312 placeholder = "Value" ,
328313 layout = widgets .Layout (width = "150px" ),
329314 )
330- tolerance_text = widgets .FloatText (
331- placeholder = "Tolerance" , layout = widgets .Layout (width = "150px" )
332- )
333- accuracy_text = widgets .FloatText (
334- placeholder = "Accuracy" , layout = widgets .Layout (width = "150px" )
335- )
315+
336316 units_text = widgets .Text (
337317 placeholder = "Units" , layout = widgets .Layout (width = "150px" )
338318 )
339319
340320 new_row = widgets .HBox (
341- [variable_dropdown , value_text , tolerance_text , accuracy_text , units_text ]
321+ [variable_dropdown , value_text , units_text ]
342322 )
343323
344324 rows_output .children += (new_row ,)
@@ -354,3 +334,126 @@ def add_req(_):
354334 display (buttons_box )
355335
356336 return result
337+
338+
339+ def display_formatted_answers (equations_dict ):
340+ """
341+ Display LaTeX formatted equations and numerical results from a nested
342+ dictionary structure in Jupyter Notebook.
343+
344+ Parameters:
345+ equations_dict (dict): The dictionary containing the equations.
346+ """
347+ results = equations_dict .get ('results' , {})
348+ print ("We identified the following equations that are relevant to your requirements:" )
349+
350+ for key , value in results .items ():
351+ latex_equation = value .get ('latex_equation' )
352+ lhs = value .get ('lhs' )
353+ rhs = value .get ('rhs' )
354+ match = value .get ('match' )
355+ if latex_equation :
356+ display (Math (latex_equation ))
357+ print (f"For provided values:\n left hand side = { lhs } \n right hand side = { rhs } " )
358+ if match :
359+ print ("Provided requirements fulfill this mathematical relation" )
360+ else :
361+ print (f"No LaTeX equation found for { key } " )
362+
363+
364+ def display_results (equations_dict ):
365+
366+ results = equations_dict .get ('results' , {})
367+ not_match_counter = 0
368+
369+ for key , value in results .items ():
370+ match = value .get ('match' )
371+ latex_equation = value .get ('latex_equation' )
372+ lhs = value .get ('lhs' )
373+ rhs = value .get ('rhs' )
374+ if not match :
375+ not_match_counter += 1
376+ print (
377+ printing .print_color .bold (
378+ printing .print_color .red (
379+ "Provided requirements DO NOT fulfill the following mathematical relation:"
380+ )
381+ )
382+ )
383+ display (Math (latex_equation ))
384+ print (f"For provided values:\n left hand side = { lhs } \n right hand side = { rhs } " )
385+ if not_match_counter == 0 :
386+ print (printing .print_color .green ("Requirements you provided do not cause any conflicts" ))
387+
388+
389+ def _get_latex_string_format (input_string ):
390+ """
391+ Properly formats LaTeX strings for matplotlib when text.usetex is False.
392+ No escaping needed since mathtext handles backslashes properly.
393+ """
394+ return f"${ input_string } $" # No backslash escaping required
395+
396+
397+ def _get_requirements_set (requirements ):
398+ variable_set = set ()
399+ for req in requirements :
400+ variable_set .add (req ['latex_symbol' ])
401+
402+ return variable_set
403+
404+
405+ def _find_vars_in_eq (equation , variable_set ):
406+ patterns = [re .escape (var ) for var in variable_set ]
407+ combined_pattern = r'|' .join (patterns )
408+ matches = re .findall (combined_pattern , equation )
409+ return {fr"${ match } $" for match in matches }
410+
411+
412+ def _add_used_vars_to_results (api_results , api_requirements ):
413+ requirements = _get_requirements_set (api_requirements )
414+
415+ for key , value in api_results ['results' ].items ():
416+ latex_equation = value .get ('latex_equation' )
417+ # print(latex_equation)
418+ if latex_equation :
419+ used_vars = _find_vars_in_eq (latex_equation , requirements )
420+ api_results ['results' ][key ]['used_vars' ] = used_vars
421+
422+ return api_results
423+
424+
425+ def get_eq_hypergraph (api_results , api_requirements ):
426+
427+ # Disable external LaTeX rendering, using matplotlib's mathtext instead
428+ plt .rcParams ['text.usetex' ] = False
429+ plt .rcParams ['mathtext.fontset' ] = 'stix'
430+ plt .rcParams ['font.family' ] = 'serif'
431+
432+ api_results = _add_used_vars_to_results (api_results , api_requirements )
433+
434+ # Prepare the data for HyperNetX visualization
435+ hyperedges = {}
436+ for eq , details in api_results ["results" ].items ():
437+ hyperedges [_get_latex_string_format (
438+ details ["latex_equation" ])] = details ["used_vars" ]
439+
440+ # Create the hypergraph using HyperNetX
441+ H = hnx .Hypergraph (hyperedges )
442+
443+ # Plot the hypergraph with enhanced clarity
444+ plt .figure (figsize = (16 , 12 ))
445+
446+ # Draw the hypergraph with node and edge labels
447+ hnx .draw (
448+ H ,
449+ with_edge_labels = True ,
450+ edge_labels_on_edge = False ,
451+ node_labels_kwargs = {'fontsize' : 14 },
452+ edge_labels_kwargs = {'fontsize' : 14 },
453+ layout_kwargs = {'seed' : 42 , 'scale' : 2.5 }
454+ )
455+
456+ plt .title (r"Enhanced Hypergraph of Equations and Variables" , fontsize = 20 )
457+ plt .show ()
458+
459+
0 commit comments