Skip to content

Commit 811af5a

Browse files
author
Dimitar Tasev
authored
Merge branch 'master' into 774_remove_free_all_by_default
2 parents f0df1f9 + b860eec commit 811af5a

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

mantidimaging/core/parallel/test/utility_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
# SPDX - License - Identifier: GPL-3.0-or-later
33

44
import mock
5+
import numpy as np
6+
import SharedArray as sa
57

6-
from mantidimaging.core.parallel.utility import multiprocessing_necessary, execute_impl
8+
from mantidimaging.core.parallel.utility import (create_array, create_shared_name, execute_impl,
9+
free_all_owned_by_this_instance, multiprocessing_necessary)
710

811

912
def test_correctly_chooses_parallel():
@@ -38,6 +41,25 @@ def test_execute_impl_par(mock_pool):
3841
assert mock_progress.update.call_count == 15
3942

4043

44+
def test_free_all_owned_by_this_instance():
45+
name1 = create_shared_name()
46+
name2 = create_shared_name()
47+
name3 = create_shared_name()
48+
create_array((10, 10), np.float32, name=name1)
49+
create_array((10, 10), np.float32, name=name2)
50+
create_array((10, 10), np.float32, name=name3)
51+
52+
temp_name = "not_this_instance"
53+
sa.create("not_this_instance", (10, 10))
54+
55+
free_all_owned_by_this_instance()
56+
assert name1 not in [arr.name.decode("utf-8") for arr in sa.list()]
57+
assert name2 not in [arr.name.decode("utf-8") for arr in sa.list()]
58+
assert name3 not in [arr.name.decode("utf-8") for arr in sa.list()]
59+
assert temp_name in [arr.name.decode("utf-8") for arr in sa.list()]
60+
sa.delete(temp_name)
61+
62+
4163
if __name__ == "__main__":
4264
import pytest
4365

mantidimaging/core/parallel/utility.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424

2525
NP_DTYPE = Type[np.single]
2626

27+
INSTANCE_PREFIX = str(uuid.uuid4())
28+
29+
30+
def free_all_owned_by_this_instance():
31+
for arr in [array for array in sa.list() if array.name.decode("utf-8").startswith(INSTANCE_PREFIX)]:
32+
sa.delete(arr.name.decode("utf-8"))
33+
2734

2835
def has_other_shared_arrays() -> bool:
2936
return len(sa.list()) > 0
@@ -35,7 +42,7 @@ def free_all():
3542

3643

3744
def create_shared_name(file_name=None) -> str:
38-
return f"{uuid.uuid4()}{f'-{os.path.basename(file_name)}' if file_name is not None else ''}"
45+
return f"{INSTANCE_PREFIX}-{uuid.uuid4()}{f'-{os.path.basename(file_name)}' if file_name is not None else ''}"
3946

4047

4148
def delete_shared_array(name, silent_failure=False):

mantidimaging/main.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import logging
88
import warnings
99

10-
import SharedArray as sa
11-
1210
from mantidimaging import helper as h
11+
from mantidimaging.core.parallel.utility import free_all_owned_by_this_instance
1312
from mantidimaging.core.utility.optional_imports import safe_import
1413

1514
formatwarning_orig = warnings.formatwarning
@@ -42,11 +41,7 @@ def parse_args():
4241

4342

4443
def main():
45-
def free_all():
46-
for arr in sa.list():
47-
sa.delete(arr.name.decode("utf-8"))
48-
49-
atexit.register(free_all)
44+
atexit.register(free_all_owned_by_this_instance)
5045
args = parse_args()
5146
# Print version number and exit
5247
if args.version:
@@ -57,7 +52,6 @@ def free_all():
5752

5853
h.initialise_logging(logging.getLevelName(args.log_level))
5954
startup_checks()
60-
free_all()
6155

6256
from mantidimaging import gui
6357
gui.execute()

0 commit comments

Comments
 (0)