Skip to content

Commit d62478e

Browse files
committed
Add json loader, saver
1 parent 08ec8dc commit d62478e

File tree

3 files changed

+180
-58
lines changed

3 files changed

+180
-58
lines changed

discorpy/losa/loadersaver.py

Lines changed: 144 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@
3131
3232
"""
3333

34+
import json
3435
import platform
3536
from pathlib import Path
3637
import h5py
3738
import numpy as np
3839
from PIL import Image
3940
import matplotlib.pyplot as plt
41+
from matplotlib import font_manager
4042
from collections import OrderedDict
4143

4244

@@ -422,8 +424,7 @@ def save_image(file_path, mat, overwrite=True):
422424
str
423425
Updated file path.
424426
"""
425-
file_path = __get_path(file_path, check_exist=False)
426-
file_path = file_path.resolve()
427+
file_path = __get_path(file_path, check_exist=False).resolve()
427428
file_ext = file_path.suffix
428429
if not ((file_ext == ".tif") or (file_ext == ".tiff")):
429430
if mat.dtype != np.uint8:
@@ -494,8 +495,29 @@ def save_plot_image(file_path, list_lines, height, width, overwrite=True,
494495
return file_path
495496

496497

498+
def __check_font(font_family):
499+
"""
500+
Check if a specific font is available in Matplotlib.
501+
502+
Parameters
503+
----------
504+
font_family : str
505+
Name of the font to check.
506+
507+
Returns
508+
-------
509+
bool
510+
True if font is available, False otherwise.
511+
"""
512+
try:
513+
font_manager.findfont(font_family, fallback_to_default=False)
514+
return True
515+
except:
516+
return False
517+
518+
497519
def save_residual_plot(file_path, list_data, height, width, overwrite=True,
498-
dpi=100):
520+
dpi=100, font_family='Times New Roman'):
499521
"""
500522
Save the plot of residual against radius to an image. Useful to check the
501523
accuracy of unwarping results.
@@ -514,6 +536,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True,
514536
Overwrite the existing file if True.
515537
dpi : int, optional
516538
The resolution in dots per inch.
539+
font_family : str, optional
540+
To set the font family
517541
518542
Returns
519543
-------
@@ -528,7 +552,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True,
528552
fig.set_size_inches(width / dpi, height / dpi)
529553
m_size = 0.5 * min(height / dpi, width / dpi)
530554
plt.rc('font', size=np.int16(m_size * 4))
531-
plt.rcParams['font.family'] = 'Times New Roman'
555+
if __check_font(font_family):
556+
plt.rcParams['font.family'] = font_family
532557
plt.rcParams['font.weight'] = 'bold'
533558
plt.xlabel('Radius', fontweight='bold')
534559
plt.ylabel('Residual', fontweight='bold')
@@ -562,8 +587,7 @@ def save_hdf_file(file_path, idata, key_path='entry', overwrite=True):
562587
str
563588
Updated file path.
564589
"""
565-
file_path = __get_path(file_path, check_exist=False)
566-
file_path = file_path.resolve()
590+
file_path = __get_path(file_path, check_exist=False).resolve()
567591
if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
568592
file_path = file_path.with_suffix('.hdf')
569593
_create_folder(str(file_path))
@@ -605,8 +629,7 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data',
605629
object
606630
hdf object.
607631
"""
608-
file_path = __get_path(file_path, check_exist=False)
609-
file_path = file_path.resolve()
632+
file_path = __get_path(file_path, check_exist=False).resolve()
610633
if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
611634
file_path = file_path.with_suffix('.hdf')
612635
_create_folder(str(file_path))
@@ -631,6 +654,60 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data',
631654
return data_out
632655

633656

657+
def save_plot_points(file_path, list_points, height, width, overwrite=True,
658+
dpi=100, marker="o", color="blue"):
659+
"""
660+
Save the plot of dot-centroids to an image. Useful to check if the dots
661+
are arranged properly where dots on the same line having the same color.
662+
663+
Parameters
664+
----------
665+
file_path : str
666+
Output file path.
667+
list_points : list of 1D-array
668+
List of the (y-x)-coordinates of points.
669+
height : int
670+
Height of the image.
671+
width : int
672+
Width of the image.
673+
overwrite : bool, optional
674+
Overwrite the existing file if True.
675+
dpi : int, optional
676+
The resolution in dots per inch.
677+
marker : str
678+
Plot marker. Full list is at:
679+
https://matplotlib.org/stable/api/markers_api.html
680+
color : str
681+
Marker color. Full list is at:
682+
https://matplotlib.org/stable/tutorials/colors/colors.html
683+
684+
Returns
685+
-------
686+
str
687+
Updated file path.
688+
"""
689+
file_path = __get_path(file_path, check_exist=False).resolve()
690+
_create_folder(str(file_path))
691+
if not overwrite:
692+
file_path = _create_file_name(str(file_path))
693+
fig = plt.figure(frameon=False)
694+
fig.set_size_inches(width / dpi, height / dpi)
695+
ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
696+
ax.set_axis_off()
697+
fig.add_axes(ax)
698+
plt.axis((0, width, 0, height))
699+
m_size = 0.5 * min(height / dpi, width / dpi)
700+
for point in list_points:
701+
plt.plot(point[1], height - point[0], marker, color=color,
702+
markersize=m_size)
703+
try:
704+
plt.savefig(file_path, dpi=dpi)
705+
except IOError:
706+
raise ValueError("Couldn't write to file {}".format(file_path))
707+
plt.close()
708+
return file_path
709+
710+
634711
def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True):
635712
"""
636713
Write metadata to a text file.
@@ -653,8 +730,7 @@ def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True):
653730
str
654731
Updated file path.
655732
"""
656-
file_path = __get_path(file_path, check_exist=False)
657-
file_path = file_path.resolve()
733+
file_path = __get_path(file_path, check_exist=False).resolve()
658734
if file_path.suffix.lower() not in {'.txt', '.dat'}:
659735
file_path = file_path.with_suffix('.txt')
660736
_create_folder(str(file_path))
@@ -698,56 +774,78 @@ def load_metadata_txt(file_path):
698774
return xcenter, ycenter, list_fact
699775

700776

701-
def save_plot_points(file_path, list_points, height, width, overwrite=True,
702-
dpi=100, marker="o", color="blue"):
777+
def __numpy_encoder(obj):
778+
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
779+
np.int16, np.int32, np.int64, np.uint8,
780+
np.uint16, np.uint32, np.uint64)):
781+
return int(obj)
782+
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
783+
return float(obj)
784+
elif isinstance(obj, (np.ndarray,)):
785+
return obj.tolist()
786+
raise TypeError(f"Object of type '{type(obj).__name__}' "
787+
f"is not JSON serializable")
788+
789+
790+
def save_metadata_json(file_path, xcenter, ycenter, list_fact, overwrite=True):
703791
"""
704-
Save the plot of dot-centroids to an image. Useful to check if the dots
705-
are arranged properly where dots on the same line having the same color.
792+
Write metadata to a JSON file.
706793
707794
Parameters
708795
----------
709796
file_path : str
710797
Output file path.
711-
list_points : list of 1D-array
712-
List of the (y-x)-coordinates of points.
713-
height : int
714-
Height of the image.
715-
width : int
716-
Width of the image.
798+
xcenter : float
799+
Center of distortion in x-direction.
800+
ycenter : float
801+
Center of distortion in y-direction.
802+
list_fact : list of float
803+
Coefficients of a polynomial.
717804
overwrite : bool, optional
718-
Overwrite the existing file if True.
719-
dpi : int, optional
720-
The resolution in dots per inch.
721-
marker : str
722-
Plot marker. Full list is at:
723-
https://matplotlib.org/stable/api/markers_api.html
724-
color : str
725-
Marker color. Full list is at:
726-
https://matplotlib.org/stable/tutorials/colors/colors.html
805+
Overwrite an existing file if True.
727806
728807
Returns
729808
-------
730809
str
731810
Updated file path.
732811
"""
733-
file_path = __get_path(file_path, check_exist=False)
734-
file_path = file_path.resolve()
812+
# Get resolved file path and set to JSON suffix
813+
file_path = __get_path(file_path, check_exist=False).resolve()
814+
if file_path.suffix.lower() != '.json':
815+
file_path = file_path.with_suffix('.json')
735816
_create_folder(str(file_path))
817+
736818
if not overwrite:
737819
file_path = _create_file_name(str(file_path))
738-
fig = plt.figure(frameon=False)
739-
fig.set_size_inches(width / dpi, height / dpi)
740-
ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
741-
ax.set_axis_off()
742-
fig.add_axes(ax)
743-
plt.axis((0, width, 0, height))
744-
m_size = 0.5 * min(height / dpi, width / dpi)
745-
for point in list_points:
746-
plt.plot(point[1], height - point[0], marker, color=color,
747-
markersize=m_size)
748-
try:
749-
plt.savefig(file_path, dpi=dpi)
750-
except IOError:
751-
raise ValueError("Couldn't write to file {}".format(file_path))
752-
plt.close()
820+
821+
# Create metadata dictionary
822+
metadata = {
823+
'xcenter': float(xcenter),
824+
'ycenter': float(ycenter),
825+
'list_fact': list_fact
826+
}
827+
with open(file_path, "w") as f:
828+
json.dump(metadata, f, indent=4, default=__numpy_encoder)
753829
return file_path
830+
831+
832+
def load_metadata_json(file_path):
833+
"""
834+
Load distortion coefficients from a JSON file.
835+
836+
Parameters
837+
----------
838+
file_path : str
839+
Path to a JSON file.
840+
841+
Returns
842+
-------
843+
tuple of floats and list
844+
Tuple of (xcenter, ycenter, list_fact).
845+
"""
846+
with open(__get_path(file_path), 'r') as f:
847+
metadata = json.load(f)
848+
xcenter = metadata['xcenter']
849+
ycenter = metadata['ycenter']
850+
list_fact = metadata['list_fact']
851+
return xcenter, ycenter, list_fact

examples/example_01.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@
167167
np.abs(corrected_mat - mat0))
168168
io.save_metadata_txt(output_base + "/coefficients_bw.txt", xcenter, ycenter,
169169
list_fact)
170+
# io.save_metadata_json(output_base + "/coefficients_bw.json", xcenter, ycenter,
171+
# list_fact)
170172

171173
# Check the correction results
172174
list_uhor_lines = post.unwarp_line_backward(list_hor_lines, xcenter, ycenter,
@@ -205,6 +207,8 @@
205207
np.abs(corrected_mat - mat0))
206208
io.save_metadata_txt(output_base + "coefficients_fw.txt", xcenter, ycenter,
207209
list_fact)
210+
# io.save_metadata_json(output_base + "coefficients_fw.json", xcenter, ycenter,
211+
# list_fact)
208212

209213
# Check the correction results
210214
list_uhor_lines = post.unwarp_line_forward(list_hor_lines, xcenter, ycenter,
@@ -247,6 +251,8 @@
247251
np.abs(corrected_mat - mat0))
248252
io.save_metadata_txt(
249253
output_base + "/coefficients_bwfw.txt", xcenter, ycenter, list_bfact)
254+
# io.save_metadata_json(
255+
# output_base + "/coefficients_bwfw.json", xcenter, ycenter, list_bfact)
250256
# Check the correction results
251257
list_uhor_lines = post.unwarp_line_backward(
252258
list_hor_lines, xcenter, ycenter, list_bfact)

tests/test_loadersaver.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,19 @@ def test_open_hdf_stream(self):
239239
self.assertRaises(ValueError, f_alias, "./tmp/data/data4.hdf",
240240
(64, 64), options={"energy/entry/data": 25.0})
241241

242+
def test_save_plot_points(self):
243+
f_alias = losa.save_plot_points
244+
list_data = np.ones((64, 2), dtype=np.float32)
245+
file_path = "./tmp/data/plot1.png"
246+
list_data[:, 0] = 0.5 * np.random.rand(64)
247+
list_data[:, 1] = np.arange(64)
248+
f_alias(file_path, list_data, 64, 64, dpi=100)
249+
self.assertTrue(os.path.isfile(file_path))
250+
251+
path = f_alias(file_path, list_data, 64, 64,
252+
dpi=100, overwrite=False)
253+
self.assertTrue(os.path.isfile(path))
254+
242255
def test_save_metadata_txt(self):
243256
f_alias = losa.save_metadata_txt
244257
file_path = "./tmp/data/coef.txt"
@@ -253,21 +266,26 @@ def test_save_metadata_txt(self):
253266
self.assertTrue(os.path.isfile(file_path + ".txt"))
254267

255268
def test_load_metadata_txt(self):
256-
f_alias = losa.load_metadata_txt
257269
file_path = "./tmp/data/coef1.txt"
258270
losa.save_metadata_txt(file_path, 31.0, 32.0, [1.0, 0.0])
259-
(x, y, facts) = f_alias(file_path)
271+
(x, y, facts) = losa.load_metadata_txt(file_path)
260272
self.assertTrue(((x == 31.0) and (y == 32.0)) and facts == [1.0, 0.0])
261273

262-
def test_save_plot_points(self):
263-
f_alias = losa.save_plot_points
264-
list_data = np.ones((64, 2), dtype=np.float32)
265-
file_path = "./tmp/data/plot1.png"
266-
list_data[:, 0] = 0.5 * np.random.rand(64)
267-
list_data[:, 1] = np.arange(64)
268-
f_alias(file_path, list_data, 64, 64, dpi=100)
274+
def test_save_metadata_json(self):
275+
f_alias = losa.save_metadata_json
276+
file_path = "./tmp/data/coef.json"
277+
f_alias(file_path, 31, 32, [1.0, 0.0])
269278
self.assertTrue(os.path.isfile(file_path))
270279

271-
path = f_alias(file_path, list_data, 64, 64,
272-
dpi=100, overwrite=False)
273-
self.assertTrue(os.path.isfile(path))
280+
path = f_alias(file_path, 31, 32, [1.0, 0.0], overwrite=False)
281+
self.assertTrue(path != file_path)
282+
283+
file_path_no_ext = "./tmp/data/coef1"
284+
f_alias(file_path_no_ext, 31, 32, [1.0, 0.0])
285+
self.assertTrue(os.path.isfile(file_path_no_ext + ".json"))
286+
287+
def test_load_metadata_json(self):
288+
file_path = "./tmp/data/coef1.json"
289+
losa.save_metadata_json(file_path, 31.0, 32.0, [1.0, 0.0])
290+
x, y, facts = losa.load_metadata_json(file_path)
291+
self.assertTrue((x == 31.0) and (y == 32.0) and facts == [1.0, 0.0])

0 commit comments

Comments
 (0)