|
1 | 1 | import ipywidgets as widgets # type: ignore |
2 | | -from IPython.display import display, Math # type: ignore |
| 2 | +from IPython.display import display, Math, HTML # type: ignore |
3 | 3 | from dataclasses import dataclass, field |
4 | 4 | import hypernetx as hnx |
5 | 5 | import matplotlib.pyplot as plt |
6 | 6 | import re |
7 | | -from ax_core.utils import printing |
8 | 7 |
|
9 | 8 | OPTION_LIST = { |
10 | 9 | "Select a template": [], |
@@ -373,17 +372,25 @@ def display_results(equations_dict): |
373 | 372 | rhs = value.get('rhs') |
374 | 373 | if not match: |
375 | 374 | not_match_counter += 1 |
376 | | - print( |
377 | | - printing.print_color.bold( |
378 | | - printing.print_color.red( |
379 | | - "Provided requirements DO NOT fulfill the following mathematical relation:" |
380 | | - ) |
381 | | - ) |
382 | | - ) |
| 375 | + display(HTML( |
| 376 | + '<p style="color:red; ' |
| 377 | + 'font-weight:bold; ' |
| 378 | + 'font-family:\'Times New Roman\'; ' |
| 379 | + 'font-size:16px;">' |
| 380 | + 'Provided requirements DO NOT fulfill the following mathematical relation:' |
| 381 | + '</p>' |
| 382 | + )) |
383 | 383 | display(Math(latex_equation)) |
384 | 384 | print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}") |
385 | 385 | if not_match_counter == 0: |
386 | | - print(printing.print_color.green("Requirements you provided do not cause any conflicts")) |
| 386 | + display(HTML( |
| 387 | + '<p style="color:green; ' |
| 388 | + 'font-weight:bold; ' |
| 389 | + 'font-family:\'Times New Roman\'; ' |
| 390 | + 'font-size:16px;">' |
| 391 | + 'Requirements you provided do not cause any conflicts' |
| 392 | + '</p>' |
| 393 | + )) |
387 | 394 |
|
388 | 395 |
|
389 | 396 | def _get_latex_string_format(input_string): |
@@ -423,7 +430,6 @@ def _add_used_vars_to_results(api_results, api_requirements): |
423 | 430 |
|
424 | 431 |
|
425 | 432 | def get_eq_hypergraph(api_results, api_requirements): |
426 | | - |
427 | 433 | # Disable external LaTeX rendering, using matplotlib's mathtext instead |
428 | 434 | plt.rcParams['text.usetex'] = False |
429 | 435 | plt.rcParams['mathtext.fontset'] = 'stix' |
@@ -453,7 +459,36 @@ def get_eq_hypergraph(api_results, api_requirements): |
453 | 459 | layout_kwargs={'seed': 42, 'scale': 2.5} |
454 | 460 | ) |
455 | 461 |
|
| 462 | + node_labels = list(H.nodes) |
| 463 | + symbol_explanations = _get_node_names_for_node_lables(node_labels, api_requirements) |
| 464 | + |
| 465 | + # Adding the symbol explanations as a legend |
| 466 | + explanation_text = "\n".join([f"${symbol}$: {desc}" for symbol, desc in symbol_explanations]) |
| 467 | + plt.annotate( |
| 468 | + explanation_text, |
| 469 | + xy=(1.05, 0.5), |
| 470 | + xycoords='axes fraction', |
| 471 | + fontsize=14, |
| 472 | + verticalalignment='center' |
| 473 | + ) |
| 474 | + |
456 | 475 | plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20) |
457 | 476 | plt.show() |
458 | 477 |
|
459 | 478 |
|
| 479 | +def _get_node_names_for_node_lables(node_labels, api_requirements): |
| 480 | + |
| 481 | + # Create the output list |
| 482 | + node_names = [] |
| 483 | + |
| 484 | + # Iterate through each symbol in S |
| 485 | + for symbol in node_labels: |
| 486 | + # Search for the matching requirement |
| 487 | + symbol = symbol.replace("$", "") |
| 488 | + for req in api_requirements: |
| 489 | + if req['latex_symbol'] == symbol: |
| 490 | + # Add the matching tuple to SS |
| 491 | + node_names.append((req["latex_symbol"], req["requirement_name"])) |
| 492 | + break # Stop searching once a match is found |
| 493 | + |
| 494 | + return node_names |
0 commit comments