Skip to content

Commit ba197f9

Browse files
committed
Refactor the StreamlitReportView class by splitting the _generate_sections method into smaller methods, similar as for QuartoReportView
1 parent 79df207 commit ba197f9

File tree

3 files changed

+110
-119
lines changed

3 files changed

+110
-119
lines changed

report/quarto_reportview.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def generate_report(self, output_dir: str = 'quarto_report/') -> None:
5050
# Define the YAML header for the quarto report
5151
yaml_header = self._create_yaml_header()
5252

53-
# Create qmd content for the report
53+
# Create qmd content and imports for the report
5454
qmd_content = []
5555
report_imports = []
5656

@@ -274,7 +274,7 @@ def _generate_plot_content(self, plot, is_report_static) -> List[str]:
274274

275275
def _generate_plot_code(self, plot, output_file = "") -> str:
276276
"""
277-
Create the code template for a plot based on its visualization tool.
277+
Create the plot code based on its visualization tool.
278278
279279
Parameters
280280
----------
@@ -285,27 +285,27 @@ def _generate_plot_code(self, plot, output_file = "") -> str:
285285
Returns
286286
-------
287287
str
288-
The generated code template as a string.
288+
The generated plot code as a string.
289289
"""
290290
# Start with the common data loading code
291-
template = f"""```{{python}}
291+
plot_code = f"""```{{python}}
292292
#| label: {plot.name}
293293
#| echo: false
294294
with open('{os.path.join("..", plot.file_path)}', 'r') as plot_file:
295295
plot_data = plot_file.read()
296296
"""
297297
# Add specific code for each visualization tool
298298
if plot.visualization_tool == r.VisualizationTool.PLOTLY:
299-
template += """fig_plotly = pio.from_json(plot_data)
299+
plot_code += """fig_plotly = pio.from_json(plot_data)
300300
fig_plotly.update_layout(width=950, height=500)
301301
"""
302302
elif plot.visualization_tool == r.VisualizationTool.ALTAIR:
303-
template += """fig_altair = alt.Chart.from_json(plot_data).properties(width=900, height=400)"""
303+
plot_code += """fig_altair = alt.Chart.from_json(plot_data).properties(width=900, height=400)"""
304304
elif plot.visualization_tool == r.VisualizationTool.PYVIS:
305-
template = f"""<div style="text-align: center;">
305+
plot_code = f"""<div style="text-align: center;">
306306
<iframe src="{os.path.join("..", output_file)}" alt="{plot.name} plot" width="800px" height="630px"></iframe>
307307
</div>\n"""
308-
return template
308+
return plot_code
309309

310310
def _generate_dataframe_content(self, dataframe, is_report_static) -> List[str]:
311311
"""

report/report.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,6 @@ class Component(ABC):
6464
component_type: ComponentType
6565
title: Optional[str] = None
6666
caption: Optional[str] = None
67-
68-
@abstractmethod
69-
def generate_imports(self) -> str:
70-
"""
71-
Generate the import statements required for the component.
72-
73-
Returns
74-
-------
75-
str
76-
A string representing the import statements needed for the component.
77-
"""
78-
pass
7967

8068
class Plot(Component):
8169
"""
@@ -104,23 +92,6 @@ def __init__(self, identifier: int, name: str, file_path: str, plot_type: PlotTy
10492
self.plot_type = plot_type
10593
self.visualization_tool = visualization_tool
10694
self.csv_network_format = csv_network_format
107-
108-
def generate_imports(self) -> str:
109-
"""
110-
Generate the import statements required for the visualization tool.
111-
112-
Returns
113-
-------
114-
str
115-
A string representing the import statements needed for the plot.
116-
"""
117-
imports = []
118-
imports.append('import json')
119-
if self.visualization_tool == VisualizationTool.ALTAIR:
120-
imports.append('import altair as alt')
121-
elif self.visualization_tool == VisualizationTool.PLOTLY:
122-
imports.append('import plotly.io as pio')
123-
return "\n".join(imports)
12495

12596
def read_network(self) -> nx.Graph:
12697
"""
@@ -294,35 +265,12 @@ def __init__(self, identifier: int, name: str, file_path: str, file_format: Data
294265
self.file_format = file_format
295266
self.delimiter = delimiter
296267

297-
def generate_imports(self) -> str:
298-
"""
299-
Generate the import statements required for handling DataFrames.
300-
301-
Returns
302-
-------
303-
str
304-
A string representing the import statements needed for the DataFrame.
305-
"""
306-
return "import pandas as pd\nfrom itables import show\nimport dataframe_image as dfi"
307-
308-
309268
@dataclass
310269
class Markdown(Component):
311270
component_type = ComponentType.MARKDOWN
312271
"""
313272
A Markdown text component within a subsection of a report.
314273
"""
315-
316-
def generate_imports(self) -> str:
317-
"""
318-
Generate the import statements required for Markdown rendering.
319-
320-
Returns
321-
-------
322-
str
323-
A string representing the import statements needed for rendering Markdown.
324-
"""
325-
return "import IPython.display as display"
326274

327275
@dataclass
328276
class Subsection:

report/streamlit_reportview.py

Lines changed: 102 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,11 @@ def generate_report(self, output_dir: str = 'streamlit_report/sections') -> None
4141
os.makedirs(output_dir, exist_ok=True)
4242

4343
# Define the Streamlit imports and report manager content
44-
streamlit_imports = f'''import streamlit as st
44+
report_manag_content = []
45+
report_manag_content.append(f"""import streamlit as st\n
4546
st.set_page_config(layout="wide", page_title="{self.report.name}", page_icon="{self.report.logo}")
46-
st.logo("{self.report.logo}")
47-
'''
48-
general_head = self._format_text(text=self.report.title, type = 'header', level=1, color='#023858')
49-
report_manag_content = [streamlit_imports, general_head]
47+
st.logo("{self.report.logo}")""")
48+
report_manag_content.append(self._format_text(text=self.report.title, type = 'header', level=1, color='#023858'))
5049

5150
# Initialize a dictionary to store the navigation structure
5251
report_manag_content.append("\nsections_pages = {}")
@@ -73,8 +72,8 @@ def generate_report(self, output_dir: str = 'streamlit_report/sections') -> None
7372
report_manag_content.append(f"sections_pages['{section.name}'] = [{', '.join(subsection_page_vars)}]\n")
7473

7574
# Add navigation object to the home page content
76-
report_manag_content.append(f'''report_nav = st.navigation(sections_pages)
77-
report_nav.run()''')
75+
report_manag_content.append(f"""report_nav = st.navigation(sections_pages)
76+
report_nav.run()""")
7877

7978
# Write the navigation and general content to a Python file
8079
with open(os.path.join(output_dir, "report_manager.py"), 'w') as nav_manager:
@@ -216,7 +215,7 @@ def _generate_subsection(self, subsection) -> List[str]:
216215
"""
217216
subsection_content = []
218217
subsection_imports = []
219-
218+
220219
# Add subsection header and description
221220
subsection_content.append(self._format_text(text=subsection.name, type='header', level=3, color='#023558'))
222221
subsection_content.append(self._format_text(text=subsection.description, type='paragraph'))
@@ -228,42 +227,11 @@ def _generate_subsection(self, subsection) -> List[str]:
228227

229228
# Handle different types of components
230229
if component.component_type == r.ComponentType.PLOT:
231-
# Cast component to Plot
232-
plot = component
233-
subsection_content.extend(self._generate_plot_content(plot))
234-
230+
subsection_content.extend(self._generate_plot_content(component))
235231
elif component.component_type == r.ComponentType.DATAFRAME:
236-
# Cast component to DataFrame
237-
dataframe = component
238-
if dataframe.file_format == r.DataFrameFormat.CSV:
239-
subsection_content.append(self._format_text(text=dataframe.title, type='header', level=4, color='#2b8cbe'))
240-
if dataframe.delimiter:
241-
subsection_content.append(f"""df = pd.read_csv('{dataframe.file_path}', delimiter='{dataframe.delimiter}')
242-
st.dataframe(df, use_container_width=True)\n""")
243-
else:
244-
subsection_content.append(f"""df = pd.read_csv('{dataframe.file_path}')
245-
st.dataframe(df, use_container_width=True)\n""")
246-
elif dataframe.file_format == r.DataFrameFormat.PARQUET:
247-
subsection_content.append(self._format_text(text=dataframe.title, type='header', level=4, color='#2b8cbe'))
248-
subsection_content.append(f"""df = pd.read_parquet('{dataframe.file_path}')
249-
st.dataframe(df, use_container_width=True)\n""")
250-
elif dataframe.file_format == r.DataFrameFormat.TXT:
251-
subsection_content.append(self._format_text(text=dataframe.title, type='header', level=4, color='#2b8cbe'))
252-
subsection_content.append(f"""df = pd.read_csv('{dataframe.file_path}', sep='\\t')
253-
st.dataframe(df, use_container_width=True)\n""")
254-
elif dataframe.file_format == r.DataFrameFormat.EXCEL:
255-
subsection_content.append(self._format_text(text=dataframe.title, type='header', level=4, color='#2b8cbe'))
256-
subsection_content.append(f"""df = pd.read_excel('{dataframe.file_path}')
257-
st.dataframe(df, use_container_width=True)\n""")
258-
else:
259-
raise ValueError(f"Unsupported DataFrame file format: {dataframe.file_format}")
232+
subsection_content.extend(self._generate_dataframe_content(component))
260233
elif component.component_type == r.ComponentType.MARKDOWN:
261-
# Cast component to Markdown
262-
markdown = component
263-
subsection_content.append(self._format_text(text=markdown.title, type='header', level=4, color='#2b8cbe'))
264-
subsection_content.append(f"""with open('{markdown.file_path}', 'r') as markdown_file:
265-
markdown_content = markdown_file.read()
266-
st.markdown(markdown_content, unsafe_allow_html=True)\n""")
234+
subsection_content.extend(self._generate_markdown_content(component))
267235
return subsection_content, subsection_imports
268236

269237
def _generate_plot_content(self, plot) -> List[str]:
@@ -286,14 +254,9 @@ def _generate_plot_content(self, plot) -> List[str]:
286254
if plot.plot_type == r.PlotType.INTERACTIVE:
287255
# Handle interactive plot
288256
if plot.visualization_tool == r.VisualizationTool.PLOTLY:
289-
plot_content.append(f"""with open('{plot.file_path}', 'r') as plot_file:
290-
plot_json = json.load(plot_file)
291-
st.plotly_chart(plot_json, use_container_width=True)\n""")
257+
plot_content.append(self._generate_plot_code(plot))
292258
elif plot.visualization_tool == r.VisualizationTool.ALTAIR:
293-
plot_content.append(f"""with open('{plot.file_path}', 'r') as plot_file:
294-
plot_json = json.load(plot_file)
295-
altair_plot = alt.Chart.from_dict(plot_json)
296-
st.vega_lite_chart(json.loads(altair_plot.to_json()), use_container_width=True)\n""")
259+
plot_content.append(self._generate_plot_code(plot))
297260
elif plot.visualization_tool == r.VisualizationTool.PYVIS:
298261
# For PyVis, handle the network visualization
299262
G = plot.read_network()
@@ -303,22 +266,103 @@ def _generate_plot_content(self, plot) -> List[str]:
303266
num_edges = len(net.edges)
304267
plot_content.append(f"""with open('{html_plot_file}', 'r') as f:
305268
html_data = f.read()
306-
307269
st.markdown(f"<p style='text-align: center; color: black;'> <b>Number of nodes:</b> {num_nodes} </p>", unsafe_allow_html=True)
308-
st.markdown(f"<p style='text-align: center; color: black;'> <b>Number of relationships:</b> {num_edges} </p>", unsafe_allow_html=True)
309-
310-
# Streamlit checkbox for controlling the layout
311-
control_layout = st.checkbox('Add panel to control layout', value=True)
312-
net_html_height = 1200 if control_layout else 630
313-
314-
# Load HTML into HTML component for display on Streamlit
315-
st.components.v1.html(html_data, height=net_html_height)\n""")
270+
st.markdown(f"<p style='text-align: center; color: black;'> <b>Number of relationships:</b> {num_edges} </p>", unsafe_allow_html=True)""")
271+
plot_content.append(self._generate_plot_code(plot))
316272
elif plot.plot_type == r.PlotType.STATIC:
317273
# Handle static plot
318274
plot_content.append(f"\nst.image('{plot.file_path}', caption='{plot.caption}', use_column_width=True)\n")
319275

320276
return plot_content
321277

278+
def _generate_plot_code(self, plot) -> str:
279+
"""
280+
Create the plot code based on its visualization tool.
281+
282+
Parameters
283+
----------
284+
plot : Plot
285+
The plot component to generate the code template for.
286+
output_file: str, optional
287+
The output html file name to be displayed with a pyvis plot.
288+
Returns
289+
-------
290+
str
291+
The generated plot code as a string.
292+
"""
293+
# Start with the common data loading code
294+
plot_code = f"""with open('{plot.file_path}', 'r') as plot_file:
295+
plot_json = json.load(plot_file)\n"""
296+
# Add specific code for each visualization tool
297+
if plot.visualization_tool == r.VisualizationTool.PLOTLY:
298+
plot_code += "st.plotly_chart(plot_json, use_container_width=True)\n"
299+
300+
elif plot.visualization_tool == r.VisualizationTool.ALTAIR:
301+
plot_code += """altair_plot = alt.Chart.from_dict(plot_json)
302+
st.vega_lite_chart(json.loads(altair_plot.to_json()), use_container_width=True)\n"""
303+
304+
elif plot.visualization_tool == r.VisualizationTool.PYVIS:
305+
plot_code = """# Streamlit checkbox for controlling the layout
306+
control_layout = st.checkbox('Add panel to control layout', value=True)
307+
net_html_height = 1200 if control_layout else 630
308+
# Load HTML into HTML component for display on Streamlit
309+
st.components.v1.html(html_data, height=net_html_height)\n"""
310+
311+
return plot_code
312+
313+
def _generate_dataframe_content(self, dataframe) -> List[str]:
314+
"""
315+
Generate content for a DataFrame component.
316+
317+
Parameters
318+
----------
319+
dataframe : DataFrame
320+
The dataframe component to generate content for.
321+
322+
Returns
323+
-------
324+
list : List[str]
325+
The list of content lines for the DataFrame.
326+
"""
327+
dataframe_content = []
328+
dataframe_content.append(self._format_text(text=dataframe.title, type='header', level=4, color='#2b8cbe'))
329+
330+
if dataframe.file_format == r.DataFrameFormat.CSV:
331+
dataframe_content.append(f"df = pd.read_csv('{dataframe.file_path}')")
332+
elif dataframe.file_format == r.DataFrameFormat.PARQUET:
333+
dataframe_content.append(f"df = pd.read_parquet('{dataframe.file_path}')")
334+
elif dataframe.file_format == r.DataFrameFormat.TXT:
335+
dataframe_content.append(f"df = pd.read_csv('{dataframe.file_path}', sep='\\t')")
336+
elif dataframe.file_format == r.DataFrameFormat.EXCEL:
337+
dataframe_content.append(f"df = pd.read_excel('{dataframe.file_path}')")
338+
else:
339+
raise ValueError(f"Unsupported DataFrame file format: {dataframe.file_format}")
340+
341+
dataframe_content.append("st.dataframe(df, use_container_width=True)")
342+
343+
return dataframe_content
344+
345+
def _generate_markdown_content(self, markdown) -> List[str]:
346+
"""
347+
Generate content for a Markdown component.
348+
349+
Parameters
350+
----------
351+
markdown : Markdown
352+
The markdown component to generate content for.
353+
354+
Returns
355+
-------
356+
list : List[str]
357+
The list of content lines for the markdown.
358+
"""
359+
markdown_content = []
360+
markdown_content.append(self._format_text(text=markdown.title, type='header', level=4, color='#2b8cbe'))
361+
markdown_content.append(f"""with open('{markdown.file_path}', 'r') as markdown_file:
362+
markdown_content = markdown_file.read()
363+
st.markdown(markdown_content, unsafe_allow_html=True)\n""")
364+
365+
return markdown_content
322366

323367
def _generate_component_imports(self, component: r.Component) -> str:
324368
"""
@@ -344,7 +388,6 @@ def _generate_component_imports(self, component: r.Component) -> str:
344388
},
345389
'dataframe': 'import pandas as pd'
346390
}
347-
348391
# Iterate over sections and subsections to determine needed imports
349392
component_type = component.component_type
350393

0 commit comments

Comments
 (0)