Skip to content

Commit 690a308

Browse files
authored
Merge pull request #313 from dmgav/dask-fix
Fix issues with recent versions of Dask and Numba
2 parents cec2606 + b7f7f48 commit 690a308

File tree

8 files changed

+152
-49
lines changed

8 files changed

+152
-49
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
host-os: ["ubuntu-latest", "macos-latest", "windows-latest"]
1818
python-version: ["3.9", "3.10", "3.11"]
19-
numpy-version: ["1.24"]
19+
numpy-version: ["1.26"]
2020
pyqt-version: ["5.15"]
2121
include:
2222
- host-os: "ubuntu-latest"
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# The original code for serializers/deserializers can be found in
2+
# 'distributed/protocol/h5py.py'
3+
4+
import distributed.protocol.h5py # noqa: F401
5+
from distributed.protocol.serialize import dask_deserialize, dask_serialize
6+
7+
deserialized_files = set()
8+
9+
10+
def serialize_h5py_file(f):
11+
if f and (f.mode != "r"):
12+
raise ValueError("Can only serialize read-only h5py files")
13+
filename = f.filename if f else None
14+
return {"filename": filename}, []
15+
16+
17+
def serialize_h5py_dataset(x):
18+
header, _ = serialize_h5py_file(x.file if x else None)
19+
header["name"] = x.name if x else None
20+
return header, []
21+
22+
23+
def deserialize_h5py_file(header, frames):
24+
import h5py
25+
26+
filename = header["filename"]
27+
if filename:
28+
file = h5py.File(filename, mode="r")
29+
deserialized_files.add(file)
30+
else:
31+
file = None
32+
return file
33+
34+
35+
def deserialize_h5py_dataset(header, frames):
36+
file = deserialize_h5py_file(header, frames)
37+
name = header["name"]
38+
dset = file[name] if (file and name) else None
39+
return dset
40+
41+
42+
def dask_set_custom_serializers():
43+
import h5py
44+
45+
dask_serialize.register((h5py.Group, h5py.Dataset), serialize_h5py_dataset)
46+
dask_serialize.register(h5py.File, serialize_h5py_file)
47+
dask_deserialize.register((h5py.Group, h5py.Dataset), deserialize_h5py_dataset)
48+
dask_deserialize.register(h5py.File, deserialize_h5py_file)
49+
50+
51+
def dask_close_all_files():
52+
while deserialized_files:
53+
file = deserialized_files.pop()
54+
if file:
55+
file.close()

pyxrf/core/map_processing.py

Lines changed: 88 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from numba import jit
1515
from progress.bar import Bar
1616

17+
from .dask_h5py_serializers import dask_close_all_files, dask_set_custom_serializers
1718
from .fitting import fit_spectrum
1819

1920
logger = logging.getLogger(__name__)
@@ -519,6 +520,9 @@ def compute_total_spectrum(
519520
else:
520521
client_is_local = False
521522

523+
client.run(dask_set_custom_serializers)
524+
dask_set_custom_serializers()
525+
522526
n_workers = len(client.scheduler_info()["workers"])
523527
logger.info(f"Dask distributed client: {n_workers} workers")
524528

@@ -536,9 +540,12 @@ def compute_total_spectrum(
536540
if file_obj:
537541
file_obj.close()
538542

539-
# The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
540-
del result_fut
541-
_dask_release_file_descriptors(client=client)
543+
client.run(dask_close_all_files)
544+
dask_close_all_files()
545+
546+
# # The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
547+
# del result_fut
548+
# _dask_release_file_descriptors(client=client)
542549

543550
if client_is_local:
544551
client.close()
@@ -614,6 +621,9 @@ def compute_total_spectrum_and_count(
614621
else:
615622
client_is_local = False
616623

624+
client.run(dask_set_custom_serializers)
625+
dask_set_custom_serializers()
626+
617627
n_workers = len(client.scheduler_info()["workers"])
618628
logger.info(f"Dask distributed client: {n_workers} workers")
619629

@@ -632,9 +642,12 @@ def compute_total_spectrum_and_count(
632642
if file_obj:
633643
file_obj.close()
634644

635-
# The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
636-
del result_fut
637-
_dask_release_file_descriptors(client=client)
645+
client.run(dask_close_all_files)
646+
dask_close_all_files()
647+
648+
# # The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
649+
# del result_fut
650+
# _dask_release_file_descriptors(client=client)
638651

639652
if client_is_local:
640653
client.close()
@@ -710,30 +723,30 @@ def _fit_xrf_block(data, data_sel_indices, matv, snip_param, use_snip):
710723
return data_out
711724

712725

713-
def _dask_release_file_descriptors(*, client):
714-
"""
715-
Make sure the Dask Client releases descriptors of the HDF5 files opened in read-only mode
716-
so that they could be opened for reading.
717-
"""
718-
# Runs small task on Dask client. Starting from v2021.7.0, Dask Distributed does not always
719-
# close HDF5 files, that are open in read-only mode for loading raw data. Submitting and
720-
# computing a small unrelated tasks seem to prompt the client to release the resources from
721-
# the previous task and close the files.
722-
rfut = da.sum(da.random.random((1000,), chunks=(10,))).persist(scheduler=client)
723-
rfut.compute(scheduler=client)
726+
# def _dask_release_file_descriptors(*, client):
727+
# """
728+
# Make sure the Dask Client releases descriptors of the HDF5 files opened in read-only mode
729+
# so that they could be opened for reading.
730+
# """
731+
# # Runs small task on Dask client. Starting from v2021.7.0, Dask Distributed does not always
732+
# # close HDF5 files, that are open in read-only mode for loading raw data. Submitting and
733+
# # computing a small unrelated tasks seem to prompt the client to release the resources from
734+
# # the previous task and close the files.
735+
# rfut = da.sum(da.random.random((1000,), chunks=(10,))).persist(scheduler=client)
736+
# rfut.compute(scheduler=client)
724737

725-
current_os = platform.system()
726-
if current_os == "Linux":
727-
# Starting with Dask/Distributed version 2022.2.0 the following step is required:
728-
# https://distributed.dask.org/en/stable/worker-memory.html#manually-trim-memory
729-
# (works for Linux only, there are different solutions for other OS if needed)
730-
import ctypes
738+
# current_os = platform.system()
739+
# if current_os == "Linux":
740+
# # Starting with Dask/Distributed version 2022.2.0 the following step is required:
741+
# # https://distributed.dask.org/en/stable/worker-memory.html#manually-trim-memory
742+
# # (works for Linux only, there are different solutions for other OS if needed)
743+
# import ctypes
731744

732-
def trim_memory() -> int:
733-
libc = ctypes.CDLL("libc.so.6")
734-
return libc.malloc_trim(0)
745+
# def trim_memory() -> int:
746+
# libc = ctypes.CDLL("libc.so.6")
747+
# return libc.malloc_trim(0)
735748

736-
client.run(trim_memory)
749+
# client.run(trim_memory)
737750

738751

739752
def fit_xrf_map(
@@ -857,6 +870,9 @@ def fit_xrf_map(
857870
else:
858871
client_is_local = False
859872

873+
client.run(dask_set_custom_serializers)
874+
dask_set_custom_serializers()
875+
860876
n_workers = len(client.scheduler_info()["workers"])
861877
logger.info(f"Dask distributed client: {n_workers} workers")
862878

@@ -881,9 +897,12 @@ def fit_xrf_map(
881897
if data_is_from_file:
882898
file_obj.close()
883899

884-
# The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
885-
del result_fut
886-
_dask_release_file_descriptors(client=client)
900+
client.run(dask_close_all_files)
901+
dask_close_all_files()
902+
903+
# # The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
904+
# del result_fut
905+
# _dask_release_file_descriptors(client=client)
887906

888907
if client_is_local:
889908
client.close()
@@ -1070,6 +1089,9 @@ def compute_selected_rois(
10701089
else:
10711090
client_is_local = False
10721091

1092+
client.run(dask_set_custom_serializers)
1093+
dask_set_custom_serializers()
1094+
10731095
n_workers = len(client.scheduler_info()["workers"])
10741096
logger.info(f"Dask distributed client: {n_workers} workers")
10751097

@@ -1102,9 +1124,12 @@ def compute_selected_rois(
11021124
if file_obj:
11031125
file_obj.close()
11041126

1127+
client.run(dask_close_all_files)
1128+
dask_close_all_files()
1129+
11051130
# The following code is needed to cause Dask 'distributed>=2021.7.0' to close the h5file.
1106-
del result_fut
1107-
_dask_release_file_descriptors(client=client)
1131+
# del result_fut
1132+
# _dask_release_file_descriptors(client=client)
11081133

11091134
if client_is_local:
11101135
client.close()
@@ -1237,14 +1262,37 @@ def snip_method_numba(
12371262
# where there are peaks. On the boundary part, we don't care
12381263
# the accuracy so much. But we need to pay attention to edge
12391264
# effects in general convolution.
1240-
A = s.sum()
1241-
1242-
background = np.convolve(background, s) / A
1243-
# Trim 'background' array to imitate the np.convolve option 'mode="same"'
1244-
mg = len(s) - 1
1245-
n_beg = mg // 2
1246-
n_end = n_beg - mg # Negative
1247-
background = background[n_beg:n_end]
1265+
1266+
def convolve(background, s):
1267+
# Modifies the contents of the 'background' array.
1268+
# This implementation of convolution replaces the original
1269+
# implementation based on 'np.convolve'. Seems to work as fast
1270+
# as the original implementation.
1271+
s_len = len(s)
1272+
n_beg = (s_len - 1) // 2
1273+
A = s.sum()
1274+
source = np.hstack(
1275+
(
1276+
np.zeros(n_beg, dtype=background.dtype),
1277+
background,
1278+
np.zeros(s_len - n_beg, dtype=background.dtype),
1279+
)
1280+
)
1281+
for n in range(len(background)):
1282+
background[n] = np.sum(source[n : n + s_len] * s) / A
1283+
1284+
convolve(background, s)
1285+
1286+
# # The following implementation of convolution stopped working because of
1287+
# # unclear issues with 'np.convolve' (gave 'List index out of range' error),
1288+
# # The code is left for reference.
1289+
# A = s.sum()
1290+
# background = np.convolve(background, s) / A
1291+
# # Trim 'background' array to imitate the np.convolve option 'mode="same"'
1292+
# mg = len(s) - 1
1293+
# n_beg = mg // 2
1294+
# n_end = n_beg - mg # Negative
1295+
# background = background[n_beg:n_end]
12481296

12491297
window_p = width * fwhm / e_lin
12501298
if spectral_binning is not None and spectral_binning > 0:

pyxrf/gui_module/wnd_load_quant_calibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def display_standard_selection_table(self):
344344
ttip = f"Fluorescence (F): {fluorescence:12g}\nDensity (D): {density:12g}\n"
345345
# Avoid very small values of density (probably zero)
346346
if abs(density) > 1e-30:
347-
ttip += f"F/D: {fluorescence/density:12g}"
347+
ttip += f"F/D: {fluorescence / density:12g}"
348348

349349
item.setToolTip(ttip)
350350

pyxrf/model/command_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def get_positions_set(img_dict):
285285
channel_num = len(param_channel_list)
286286
for i in range(channel_num):
287287
inner_path = "xrfmap/" + det_channel_names[i]
288-
print(f"Processing data from detector channel {det_channel_names[i]} (#{i+1}) ...")
288+
print(f"Processing data from detector channel {det_channel_names[i]} (#{i + 1}) ...")
289289

290290
# load param file
291291
param_file_name = param_channel_list[i]
@@ -663,7 +663,7 @@ def pyxrf_batch(
663663
# only ``start_id`` is specified:
664664
# process only one file that contains ``start_id`` in its name
665665
# (only if such file exists)
666-
pattern = f"^[^_]*_{str(start_id)}\D+" # noqa: W605
666+
pattern = f"^[^_]*_{str(start_id)}\\D+" # noqa: W605
667667
flist = [fname for fname in all_files if re.search(pattern, os.path.basename(fname))]
668668

669669
if len(flist) < 1:
@@ -679,7 +679,7 @@ def pyxrf_batch(
679679
# select files, which contain the respective ID substring in their names
680680
flist = []
681681
for data_id in range(start_id, end_id + 1):
682-
pattern = f"^[^_]*_{str(data_id)}\D+" # noqa: W605
682+
pattern = f"^[^_]*_{str(data_id)}\\D+" # noqa: W605
683683
flist += [fname for fname in all_files if re.search(pattern, os.path.basename(fname))]
684684
if len(flist) < 1:
685685
print(f"No files with Scan IDs in the range {start_id} .. {end_id} were found.")

pyxrf/model/fit_spectrum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ def save2Dmap_to_hdf(self, *, calculation_info=None, pixel_fit="nnls"):
774774
# det1, det2, ... , i.e. 'det' followed by integer number.
775775
# The channel name is always located at the end of the ``data_title``.
776776
# If the channel name is found, then build the path using this name.
777-
srch = re.search("det\d+$", self.data_title) # noqa: W605
777+
srch = re.search(r"det\d+$", self.data_title) # noqa: W605
778778
if srch:
779779
det_name = srch.group(0)
780780
fit_name = f"{prefix_fname}_{det_name}_fit"

pyxrf/model/roi_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def saveROImap_to_hdf(self, data_dict_roi):
346346
# det1, det2, ... , i.e. 'det' followed by integer number.
347347
# The channel name is always located at the end of the ``data_title``.
348348
# If the channel name is found, then build the path using this name.
349-
srch = re.search("det\d+$", self.data_title) # noqa: W605
349+
srch = re.search(r"det\d+$", self.data_title) # noqa: W605
350350
if srch:
351351
det_name = srch.group(0)
352352
inner_path = f"xrfmap/{det_name}"

pyxrf/xanes_maps/xanes_maps_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3381,12 +3381,12 @@ def _save_xanes_maps_to_tiff(
33813381
print(f" image size (Ny, Nx): ({n_y_pixels}, {n_x_pixels})", file=f_log)
33823382
print(
33833383
f" Y-axis scan range [Y_min, Y_max, abs(Y_max-Y_min)]: "
3384-
f"[{y_min:.5g}, {y_max:.5g}, {abs(y_max-y_min):.5g}]",
3384+
f"[{y_min:.5g}, {y_max:.5g}, {abs(y_max - y_min):.5g}]",
33853385
file=f_log,
33863386
)
33873387
print(
33883388
f" X-axis scan range [X_min, X_max, abs(X_max-X_min)]: "
3389-
f"[{x_min:.5g}, {x_max:.5g}, {abs(x_max-x_min):.5g}]",
3389+
f"[{x_min:.5g}, {x_max:.5g}, {abs(x_max - x_min):.5g}]",
33903390
file=f_log,
33913391
)
33923392

0 commit comments

Comments
 (0)