Skip to content

Commit cdeb2a6

Browse files
committed
Remove int_visualization_tool parameter and keep just plot_type. Modify the plot code generation for streamlit and quarto reports accordingly.
1 parent 06eddac commit cdeb2a6

File tree

6 files changed

+89
-113
lines changed

6 files changed

+89
-113
lines changed

report_metadata_micw2graph.yaml

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ report:
1111
and potential ecological functions.
1212
graphical_abstract: "example_data/MicW2Graph/Methods_MicW2Graph.png"
1313
logo: "example_data/mona_logo.png"
14-
report_type: "streamlit"
15-
report_format: ""
14+
report_type: "document"
15+
report_format: "docx"
1616
sections:
1717
- title: "Exploratory Data Analysis"
1818
subsections:
@@ -21,15 +21,13 @@ sections:
2121
components:
2222
- title: "Top 5 species by biome (plotly)"
2323
component_type: "plot"
24-
plot_type: "interactive"
24+
plot_type: "plotly"
2525
file_path: "example_data/MicW2Graph/top_species_plot_biome.json"
26-
int_visualization_tool: "plotly"
2726
caption: "Optional caption"
2827
- title: "Multiline plot (altair)"
2928
component_type: "plot"
30-
plot_type: "interactive"
29+
plot_type: "altair"
3130
file_path: "example_data/altair_multilineplot.json"
32-
int_visualization_tool: "altair"
3331
- title: "Abundance data for all studies (csv)"
3432
component_type: "dataframe"
3533
file_path: "example_data/MicW2Graph/abundance_data_allbiomes.csv"
@@ -48,9 +46,8 @@ sections:
4846
file_path: "example_data/MicW2Graph/number_samples_per_study.png"
4947
- title: "Sampling countries for all studies (plotly)"
5048
component_type: "plot"
51-
plot_type: "interactive"
49+
plot_type: "plotly"
5250
file_path: "example_data/MicW2Graph/pie_plot_countries.json"
53-
int_visualization_tool: "plotly"
5451
- title: "Sample data for all studies (txt)"
5552
component_type: "dataframe"
5653
file_path: "example_data/MicW2Graph/sample_info_allbiomes.txt"
@@ -70,15 +67,13 @@ sections:
7067
components:
7168
- title: "Network1 (graphml)"
7269
component_type: "plot"
73-
plot_type: "interactive"
74-
int_visualization_tool: "pyvis"
70+
plot_type: "interactive_network"
7571
file_path: "example_data/MicW2Graph/man_example.graphml"
7672
- title: "Network Visualization2"
7773
components:
7874
- title: "Network2 (edge list csv)"
7975
component_type: "plot"
80-
plot_type: "interactive"
81-
int_visualization_tool: "pyvis"
76+
plot_type: "interactive_network"
8277
csv_network_format: "edgelist"
8378
file_path: "example_data/MicW2Graph/man_example.csv"
8479
- title: "Edge list"

vuegen/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
report_config = load_yaml_config(config_path)
88

99
# Define logger suffix based on report engine, type and name
10-
report_engine = "streamlit"
10+
report_engine = "quarto"
1111
report_type = report_config['report'].get('report_type')
1212
report_format = report_config['report'].get('report_format')
1313
report_name = report_config['report'].get('name')

vuegen/metadata_manager.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,6 @@ def _create_plot_component(self, component_data: dict) -> r.Plot:
156156
"""
157157
# Validate enum fields
158158
plot_type = assert_enum_value(r.PlotType, component_data['plot_type'], self.logger)
159-
int_visualization_tool = (assert_enum_value(r.IntVisualizationTool, component_data.get('int_visualization_tool', ''), self.logger)
160-
if component_data.get('int_visualization_tool') else None)
161159
csv_network_format = (assert_enum_value(r.CSVNetworkFormat, component_data.get('csv_network_format', ''), self.logger)
162160
if component_data.get('csv_network_format') else None)
163161

@@ -166,7 +164,6 @@ def _create_plot_component(self, component_data: dict) -> r.Plot:
166164
logger = self.logger,
167165
file_path = component_data['file_path'],
168166
plot_type = plot_type,
169-
int_visualization_tool = int_visualization_tool,
170167
csv_network_format = csv_network_format,
171168
caption = component_data.get('caption')
172169
)

vuegen/quarto_reportview.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -265,54 +265,50 @@ def _generate_plot_content(self, plot, is_report_static, static_dir: str = STATI
265265
plot_content = []
266266
# Add title
267267
plot_content.append(f'### {plot.title}')
268+
269+
# Define plot path
270+
if is_report_static:
271+
static_plot_path = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.png")
272+
else:
273+
html_plot_file = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.html")
268274

269-
if plot.plot_type == r.PlotType.INTERACTIVE:
270-
try:
271-
# Define plot path
275+
# Add content for the different plot types
276+
try:
277+
if plot.plot_type == r.PlotType.STATIC:
278+
plot_content.append(self._generate_image_content(plot.file_path, width=950))
279+
elif plot.plot_type == r.PlotType.PLOTLY:
280+
plot_content.append(self._generate_plot_code(plot))
272281
if is_report_static:
273-
static_plot_path = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.png")
282+
plot_content.append(f"""fig_plotly.write_image("{os.path.join("..", static_plot_path)}")\n```\n""")
283+
plot_content.append(self._generate_image_content(static_plot_path))
274284
else:
275-
html_plot_file = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.html")
276-
277-
if plot.int_visualization_tool == r.IntVisualizationTool.PLOTLY:
278-
plot_content.append(self._generate_plot_code(plot))
279-
if is_report_static:
280-
plot_content.append(f"""fig_plotly.write_image("{os.path.join("..", static_plot_path)}")\n```\n""")
281-
plot_content.append(self._generate_image_content(static_plot_path))
282-
else:
283-
plot_content.append(f"""fig_plotly.show()\n```\n""")
284-
elif plot.int_visualization_tool == r.IntVisualizationTool.ALTAIR:
285-
plot_content.append(self._generate_plot_code(plot))
286-
if is_report_static:
287-
plot_content.append(f"""fig_altair.save("{os.path.join("..", static_plot_path)}")\n```\n""")
288-
plot_content.append(self._generate_image_content(static_plot_path))
289-
else:
290-
plot_content.append(f"""fig_altair\n```\n""")
291-
elif plot.int_visualization_tool == r.IntVisualizationTool.PYVIS:
292-
G = plot.read_network()
293-
num_nodes = G.number_of_nodes()
294-
num_edges = G.number_of_edges()
295-
plot_content.append(f'**Number of nodes:** {num_nodes}\n')
296-
plot_content.append(f'**Number of edges:** {num_edges}\n')
297-
if is_report_static:
298-
plot.save_netwrok_image(G, static_plot_path, "png")
299-
plot_content.append(self._generate_image_content(static_plot_path))
300-
else:
301-
# Get the Network object
302-
net = plot.create_and_save_pyvis_network(G, html_plot_file)
303-
plot_content.append(self._generate_plot_code(plot, html_plot_file))
285+
plot_content.append(f"""fig_plotly.show()\n```\n""")
286+
elif plot.plot_type == r.PlotType.ALTAIR:
287+
plot_content.append(self._generate_plot_code(plot))
288+
if is_report_static:
289+
plot_content.append(f"""fig_altair.save("{os.path.join("..", static_plot_path)}")\n```\n""")
290+
plot_content.append(self._generate_image_content(static_plot_path))
304291
else:
305-
self.report.logger.warning(f"Unsupported interactive plot tool: {plot.int_visualization_tool}")
306-
except Exception as e:
307-
self.report.logger.error(f"Error generating interactive plot content for {plot.title}: {str(e)}")
308-
raise
309-
310-
elif plot.plot_type == r.PlotType.STATIC:
311-
try:
312-
plot_content.append(self._generate_image_content(plot.file_path, width=950))
313-
except Exception as e:
314-
self.report.logger.error(f"Error generating static plot content for {plot.title}: {str(e)}")
315-
raise
292+
plot_content.append(f"""fig_altair\n```\n""")
293+
elif plot.plot_type == r.PlotType.INTERACTIVE_NETWORK:
294+
G = plot.read_network()
295+
num_nodes = G.number_of_nodes()
296+
num_edges = G.number_of_edges()
297+
plot_content.append(f'**Number of nodes:** {num_nodes}\n')
298+
plot_content.append(f'**Number of edges:** {num_edges}\n')
299+
300+
if is_report_static:
301+
plot.save_netwrok_image(G, static_plot_path, "png")
302+
plot_content.append(self._generate_image_content(static_plot_path))
303+
else:
304+
# Get the Network object
305+
net = plot.create_and_save_pyvis_network(G, html_plot_file)
306+
plot_content.append(self._generate_plot_code(plot, html_plot_file))
307+
else:
308+
self.report.logger.warning(f"Unsupported plot type: {plot.plot_type}")
309+
except Exception as e:
310+
self.report.logger.error(f"Error generating content for '{plot.plot_type}' plot '{plot.id}' '{plot.title}': {str(e)}")
311+
raise
316312

317313
# Add caption if available
318314
if plot.caption:
@@ -344,13 +340,13 @@ def _generate_plot_code(self, plot, output_file = "") -> str:
344340
plot_data = plot_file.read()
345341
"""
346342
# Add specific code for each visualization tool
347-
if plot.int_visualization_tool == r.IntVisualizationTool.PLOTLY:
343+
if plot.plot_type == r.PlotType.PLOTLY:
348344
plot_code += """fig_plotly = pio.from_json(plot_data)
349345
fig_plotly.update_layout(width=950, height=500)
350346
"""
351-
elif plot.int_visualization_tool == r.IntVisualizationTool.ALTAIR:
347+
elif plot.plot_type == r.PlotType.ALTAIR:
352348
plot_code += """fig_altair = alt.Chart.from_json(plot_data).properties(width=900, height=400)"""
353-
elif plot.int_visualization_tool == r.IntVisualizationTool.PYVIS:
349+
elif plot.plot_type == r.PlotType.INTERACTIVE_NETWORK:
354350
plot_code = f"""<div style="text-align: center;">
355351
<iframe src="{os.path.join("..", output_file)}" alt="{plot.title} plot" width="800px" height="630px"></iframe>
356352
</div>\n"""
@@ -524,8 +520,8 @@ def _generate_component_imports(self, component: r.Component) -> List[str]:
524520
# Dictionary to hold the imports for each component type
525521
components_imports = {
526522
'plot': {
527-
r.IntVisualizationTool.ALTAIR: ['import altair as alt'],
528-
r.IntVisualizationTool.PLOTLY: ['import plotly.io as pio']
523+
r.PlotType.ALTAIR: ['import altair as alt'],
524+
r.PlotType.PLOTLY: ['import plotly.io as pio']
529525
},
530526
'dataframe': ['import pandas as pd', 'from itables import show', 'import dataframe_image as dfi'],
531527
'markdown': ['import IPython.display as display']
@@ -537,9 +533,9 @@ def _generate_component_imports(self, component: r.Component) -> List[str]:
537533

538534
# Add relevant imports based on component type and visualization tool
539535
if component_type == r.ComponentType.PLOT:
540-
int_visualization_tool = getattr(component, 'int_visualization_tool', None)
541-
if int_visualization_tool in components_imports['plot']:
542-
component_imports.extend(components_imports['plot'][int_visualization_tool])
536+
plot_type = getattr(component, 'plot_type', None)
537+
if plot_type in components_imports['plot']:
538+
component_imports.extend(components_imports['plot'][plot_type])
543539
elif component_type == r.ComponentType.DATAFRAME:
544540
component_imports.extend(components_imports['dataframe'])
545541
elif component_type == r.ComponentType.MARKDOWN:

vuegen/report.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ class ComponentType(StrEnum):
2626
CHATBOT = auto()
2727

2828
class PlotType(StrEnum):
29-
INTERACTIVE = auto()
3029
STATIC = auto()
31-
32-
class IntVisualizationTool(StrEnum):
3330
PLOTLY = auto()
3431
ALTAIR = auto()
35-
PYVIS = auto()
32+
INTERACTIVE_NETWORK = auto()
3633

3734
class NetworkFormat(StrEnum):
3835
GML = auto()
@@ -110,8 +107,7 @@ class Plot(Component):
110107
The format of the CSV file for network plots (EDGELIST or ADJLIST) (default is None).
111108
"""
112109
def __init__(self, title: str, logger: logging.Logger, plot_type: PlotType, file_path: str=None,
113-
caption: str=None, int_visualization_tool: Optional[IntVisualizationTool]=None,
114-
csv_network_format: Optional[CSVNetworkFormat]=None):
110+
caption: str=None, csv_network_format: Optional[CSVNetworkFormat]=None):
115111
"""
116112
Initializes a Plot object.
117113
"""
@@ -121,7 +117,6 @@ def __init__(self, title: str, logger: logging.Logger, plot_type: PlotType, file
121117

122118
# Set specific attributes for the Plot class
123119
self.plot_type = plot_type
124-
self.int_visualization_tool = int_visualization_tool
125120
self.csv_network_format = csv_network_format
126121

127122
def read_network(self) -> nx.Graph:

vuegen/streamlit_reportview.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -303,39 +303,33 @@ def _generate_plot_content(self, plot, static_dir: str = STATIC_FILES_DIR) -> Li
303303

304304
# Add title
305305
plot_content.append(self._format_text(text=plot.title, type='header', level=4, color='#2b8cbe'))
306-
307-
if plot.plot_type == r.PlotType.INTERACTIVE:
308-
try:
309-
# Handle interactive plot
310-
if plot.int_visualization_tool == r.IntVisualizationTool.PLOTLY:
311-
plot_content.append(self._generate_plot_code(plot))
312-
elif plot.int_visualization_tool == r.IntVisualizationTool.ALTAIR:
313-
plot_content.append(self._generate_plot_code(plot))
314-
elif plot.int_visualization_tool == r.IntVisualizationTool.PYVIS:
315-
# For PyVis, handle the network visualization
316-
G = plot.read_network()
317-
html_plot_file = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.html")
318-
net = plot.create_and_save_pyvis_network(G, html_plot_file)
319-
num_nodes = len(net.nodes)
320-
num_edges = len(net.edges)
321-
plot_content.append(f"""with open('{html_plot_file}', 'r') as f:
306+
307+
# Add content for the different plot types
308+
try:
309+
if plot.plot_type == r.PlotType.STATIC:
310+
plot_content.append(f"\nst.image('{plot.file_path}', caption='{plot.caption}', use_column_width=True)\n")
311+
elif plot.plot_type == r.PlotType.PLOTLY:
312+
plot_content.append(self._generate_plot_code(plot))
313+
elif plot.plot_type == r.PlotType.ALTAIR:
314+
plot_content.append(self._generate_plot_code(plot))
315+
elif plot.plot_type == r.PlotType.INTERACTIVE_NETWORK:
316+
# Handle the network visualization
317+
G = plot.read_network()
318+
html_plot_file = os.path.join(static_dir, f"{plot.title.replace(' ', '_')}.html")
319+
net = plot.create_and_save_pyvis_network(G, html_plot_file)
320+
num_nodes = len(net.nodes)
321+
num_edges = len(net.edges)
322+
plot_content.append(f"""with open('{html_plot_file}', 'r') as f:
322323
html_data = f.read()
323324
st.markdown(f"<p style='text-align: center; color: black;'> <b>Number of nodes:</b> {num_nodes} </p>", unsafe_allow_html=True)
324325
st.markdown(f"<p style='text-align: center; color: black;'> <b>Number of relationships:</b> {num_edges} </p>", unsafe_allow_html=True)""")
325-
plot_content.append(self._generate_plot_code(plot))
326-
else:
327-
self.report.logger.warning(f"Unsupported interactive plot tool: {plot.int_visualization_tool}")
328-
except Exception as e:
329-
self.report.logger.error(f"Error generating interactive content for plot '{plot.id}' '{plot.title}': {str(e)}")
330-
raise
326+
plot_content.append(self._generate_plot_code(plot))
327+
else:
328+
self.report.logger.warning(f"Unsupported plot type: {plot.plot_type}")
329+
except Exception as e:
330+
self.report.logger.error(f"Error generating content for '{plot.plot_type}' plot '{plot.id}' '{plot.title}': {str(e)}")
331+
raise
331332

332-
elif plot.plot_type == r.PlotType.STATIC:
333-
try:
334-
# Handle static plot
335-
plot_content.append(f"\nst.image('{plot.file_path}', caption='{plot.caption}', use_column_width=True)\n")
336-
except Exception as e:
337-
self.report.logger.error(f"Error generating content for static plot '{plot.id}' '{plot.title}': {str(e)}")
338-
raise
339333
# Add caption if available
340334
if plot.caption:
341335
plot_content.append(self._format_text(text=plot.caption, type='caption', text_align="left"))
@@ -362,20 +356,19 @@ def _generate_plot_code(self, plot) -> str:
362356
plot_code = f"""with open('{plot.file_path}', 'r') as plot_file:
363357
plot_json = json.load(plot_file)\n"""
364358
# Add specific code for each visualization tool
365-
if plot.int_visualization_tool == r.IntVisualizationTool.PLOTLY:
359+
if plot.plot_type == r.PlotType.PLOTLY:
366360
plot_code += "st.plotly_chart(plot_json, use_container_width=True)\n"
367361

368-
elif plot.int_visualization_tool == r.IntVisualizationTool.ALTAIR:
362+
elif plot.plot_type == r.PlotType.ALTAIR:
369363
plot_code += """altair_plot = alt.Chart.from_dict(plot_json)
370364
st.vega_lite_chart(json.loads(altair_plot.to_json()), use_container_width=True)\n"""
371365

372-
elif plot.int_visualization_tool == r.IntVisualizationTool.PYVIS:
366+
elif plot.plot_type == r.PlotType.INTERACTIVE_NETWORK:
373367
plot_code = """# Streamlit checkbox for controlling the layout
374368
control_layout = st.checkbox('Add panel to control layout', value=True)
375369
net_html_height = 1200 if control_layout else 630
376370
# Load HTML into HTML component for display on Streamlit
377371
st.components.v1.html(html_data, height=net_html_height)\n"""
378-
379372
return plot_code
380373

381374
def _generate_dataframe_content(self, dataframe) -> List[str]:
@@ -537,8 +530,8 @@ def _generate_component_imports(self, component: r.Component) -> List[str]:
537530
# Dictionary to hold the imports for each component type
538531
components_imports = {
539532
'plot': {
540-
r.IntVisualizationTool.ALTAIR: ['import json', 'import altair as alt'],
541-
r.IntVisualizationTool.PLOTLY: ['import json']
533+
r.PlotType.ALTAIR: ['import json', 'import altair as alt'],
534+
r.PlotType.PLOTLY: ['import json']
542535
},
543536
'dataframe': ['import pandas as pd']
544537
}
@@ -548,9 +541,9 @@ def _generate_component_imports(self, component: r.Component) -> List[str]:
548541

549542
# Add relevant imports based on component type and visualization tool
550543
if component_type == r.ComponentType.PLOT:
551-
int_visualization_tool = getattr(component, 'int_visualization_tool', None)
552-
if int_visualization_tool in components_imports['plot']:
553-
component_imports.extend(components_imports['plot'][int_visualization_tool])
544+
plot_type = getattr(component, 'plot_type', None)
545+
if plot_type in components_imports['plot']:
546+
component_imports.extend(components_imports['plot'][plot_type])
554547
elif component_type == r.ComponentType.DATAFRAME:
555548
component_imports.extend(components_imports['dataframe'])
556549

0 commit comments

Comments
 (0)