Skip to content

Commit 0eafe40

Browse files
committed
Add to_folder option for plots
1 parent ce91e0f commit 0eafe40

File tree

10 files changed

+33
-58
lines changed

10 files changed

+33
-58
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async def main():
282282
epochs=epochs,
283283
)
284284

285-
filepath = synalinks.utils.plot_history(history)
285+
synalinks.utils.plot_history(history)
286286

287287
if __name__ == "__main__":
288288
asyncio.run(main())

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ exclude_lines = [
9898
]
9999
omit = [
100100
"*/__init__.py",
101+
"*/plot_*.py",
102+
"*/*_visualization.py",
101103
]
102104

103105
[tool.coverage.run]

synalinks/src/programs/program.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def summary(
231231
line_length=None,
232232
positions=None,
233233
print_fn=None,
234-
return_string=False,
235234
expand_nested=False,
236235
show_trainable=False,
237236
module_range=None,
@@ -250,7 +249,6 @@ def summary(
250249
It will be called on each line of the summary.
251250
You can set it to a custom function
252251
in order to capture the string summary.
253-
return_string: If True, return the summary string instead of printing.
254252
expand_nested: Whether to expand the nested models.
255253
Defaults to `False`.
256254
show_trainable: Whether to show if a module is trainable.
@@ -277,7 +275,6 @@ def summary(
277275
show_trainable=show_trainable,
278276
module_range=module_range,
279277
)
280-
# TODO capture stdout and return the string
281278

282279
def save(self, filepath, overwrite=True, **kwargs):
283280
"""Saves a program as a `.json` file.

synalinks/src/utils/plot_history.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
22

3+
import os
34
import matplotlib.pyplot as plt
45
from matplotlib.ticker import MaxNLocator
56

@@ -11,6 +12,7 @@
1112
def plot_history(
1213
history,
1314
to_file="training_history.png",
15+
to_folder=None,
1416
xlabel="Epochs",
1517
ylabel="Scores",
1618
title="Training history",
@@ -47,8 +49,10 @@ def plot_history(
4749
ValueError: If there are unrecognized keyword arguments.
4850
4951
Returns:
50-
(IPython.display.Image): If running in a Jupyter notebook,
51-
returns an IPython Image object for inline display.
52+
(IPython.display.Image | marimo.Image | str):
53+
If running in a Jupyter notebook, returns an IPython Image object
54+
for inline display. If running in a Marimo notebook returns a marimo image.
55+
Otherwise returns the filepath where the image have been saved.
5256
"""
5357

5458
colors = generate_distinct_colors(len(history.history))
@@ -68,6 +72,8 @@ def plot_history(
6872
plt.ylim(0.0, 1.0)
6973
plt.legend()
7074
plt.grid(grid)
75+
if to_folder:
76+
to_file = os.path.join(to_folder, to_file)
7177
plt.savefig(to_file)
7278
plt.close()
7379
try:

synalinks/src/utils/plot_metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
22

3+
import os
34
import matplotlib.pyplot as plt
45

56
from synalinks.src.api_export import synalinks_export
@@ -10,6 +11,7 @@
1011
def plot_metrics(
1112
metrics,
1213
to_file="evaluation_metrics.png",
14+
to_folder=None,
1315
xlabel="Metrics",
1416
ylabel="Scores",
1517
title="Evaluation metrics",
@@ -42,8 +44,10 @@ def plot_metrics(
4244
ValueError: If there are unrecognized keyword arguments.
4345
4446
Returns:
45-
(IPython.display.Image): If running in a Jupyter notebook,
46-
returns an IPython Image object for inline display.
47+
(IPython.display.Image | marimo.Image | str):
48+
If running in a Jupyter notebook, returns an IPython Image object
49+
for inline display. If running in a Marimo notebook returns a marimo image.
50+
Otherwise returns the filepath where the image have been saved.
4751
"""
4852

4953
metric_names = list(metrics.keys())
@@ -60,6 +64,8 @@ def plot_metrics(
6064
plt.ylim(0.0, 1.0)
6165
plt.legend()
6266
plt.grid(grid)
67+
if to_folder:
68+
to_file = os.path.join(to_folder, to_file)
6369
plt.savefig(to_file)
6470
plt.close()
6571
try:

synalinks/src/utils/program_visualization.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def program_to_dot(
331331
def plot_program(
332332
program,
333333
to_file=None,
334+
to_folder=None,
334335
show_schemas=False,
335336
show_module_names=False,
336337
rankdir="TB",
@@ -351,10 +352,10 @@ def plot_program(
351352
outputs=outputs,
352353
)
353354
354-
dot_img_file = '/tmp/program_1.png'
355355
synalinks.utils.plot_program(
356356
program,
357-
to_file=dot_img_file,
357+
to_file="program_1.png",
358+
to_folder="/tmp",
358359
show_schemas=True,
359360
show_trainable=True,
360361
)
@@ -378,8 +379,10 @@ def plot_program(
378379
show_trainable (bool): whether to display if a module is trainable.
379380
380381
Returns:
381-
(IPython.display.Image): A Jupyter notebook Image object if Jupyter is installed.
382-
This enables in-line display of the program plots in notebooks.
382+
(IPython.display.Image | marimo.Image | str):
383+
If running in a Jupyter notebook, returns an IPython Image object
384+
for inline display. If running in a Marimo notebook returns a marimo image.
385+
Otherwise returns the filepath where the image have been saved.
383386
"""
384387

385388
if not to_file:
@@ -433,15 +436,17 @@ def plot_program(
433436
return
434437
dot = remove_unused_edges(dot)
435438
_, extension = os.path.splitext(to_file)
439+
if to_folder:
440+
to_file = os.path.join(to_folder, to_file)
436441
if not extension:
437442
extension = "png"
438443
else:
439444
extension = extension[1:]
440445
# Save image to disk.
441446
dot.write(to_file, format=extension)
442-
# Return the image as a Jupyter Image object, to be displayed in-line.
447+
# Return the image as a Jupyter Image object or Marimo Image object, to be displayed in-line.
443448
# Note that we cannot easily detect whether the code is running in a
444-
# notebook, and thus we always return the Image if Jupyter is available.
449+
# Jupyter notebook, and thus we always return the Image if Jupyter is available.
445450
if extension != "pdf":
446451
try:
447452
from IPython import display

synalinks/src/utils/summary_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def print_module(module, nested_level=0):
286286
console = rich.console.Console(highlight=False)
287287

288288
# Print the to the console.
289-
console.print(bold_text(f"Program: {rich.markup.escape(program.name)}"))
289+
console.print(f"Program: {rich.markup.escape(program.name)}")
290290
console.print(f"description: '{rich.markup.escape(program.description)}'")
291291
console.print(table)
292292

synalinks/src/utils/summary_utils_test.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

synalinks/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from synalinks.src.api_export import synalinks_export
44

55
# Unique source of truth for the version number.
6-
__version__ = "0.1.3001"
6+
__version__ = "0.1.3002"
77

88

99
@synalinks_export("synalinks.version")

0 commit comments

Comments
 (0)