
Commit 48f570c

Merge pull request #652 from DiamondLightSource/remove-stripe-fw-radway-58
Iterative search method for memory estimator
2 parents 0efd3ac + be1e8c0 commit 48f570c

File tree

6 files changed: +238 −25 lines

docs/source/pipelines/yaml.rst

Lines changed: 17 additions & 16 deletions
@@ -58,19 +58,6 @@ Those pipelines consist of methods from HTTomolibgpu (GPU) and HTTomolib (CPU) b
    .. literalinclude:: ../pipelines_full/FISTA3d_tomobar.yaml
       :language: yaml
 
-.. _tutorials_pl_templates_cpu:
-
-Pipelines using TomoPy library
-------------------------------
-
-One can build CPU-only pipelines by using mostly TomoPy methods. They are expected to be slower than the pipelines above.
-
-.. dropdown:: CPU pipeline using auto-centering and the gridrec reconstruction method from TomoPy.
-
-   .. literalinclude:: ../pipelines_full/tomopy_gridrec.yaml
-      :language: yaml
-
-
 .. _tutorials_pl_templates_dls:
 
 DLS-specific pipelines
@@ -95,17 +82,31 @@ Pipelines with parameter sweeps
 
 Here we demonstrate how to perform a sweep across multiple values of a single parameter (see :ref:`parameter_sweeping` for more details).
 
-.. note:: There is no need to add image saving plugin for sweep runs as it will be added automatically. It is also preferable to keep the `preview` small as the time of computation can be substantial.
+.. note:: There is no need to add image saving plugin for sweep runs as it will be added automatically.
 
 .. dropdown:: Parameter sweep using the :code:`!SweepRange` tag to do a sweep over several CoR values of the :code:`center` parameter in the reconstruction method.
 
    .. literalinclude:: ../pipelines_full/sweep_center_FBP3d_tomobar.yaml
       :language: yaml
-      :emphasize-lines: 34-37
+      :emphasize-lines: 36-39
 
 .. dropdown:: Parameter sweep using the :code:`!Sweep` tag over several particular values (not a range) of the :code:`ratio_delta_beta` parameter for the Paganin filter.
 
   .. literalinclude:: ../pipelines_full/sweep_paganin_FBP3d_tomobar.yaml
      :language: yaml
-     :emphasize-lines: 53-56
+     :emphasize-lines: 51-54
 
+
+.. _tutorials_pl_templates_cpu:
+
+Pipelines using TomoPy library
+------------------------------
+
+One can build CPU-only pipelines by using mostly TomoPy methods.
+
+.. note:: Methods from TomoPy are expected to be slower than the GPU-accelerated methods from the libraries above.
+
+.. dropdown:: CPU pipeline using auto-centering and the gridrec reconstruction method from TomoPy.
+
+   .. literalinclude:: ../pipelines_full/tomopy_gridrec.yaml
+      :language: yaml

httomo/method_wrappers/generic.py

Lines changed: 44 additions & 1 deletion
@@ -14,7 +14,14 @@
     MethodRepository,
 )
 from httomo.runner.output_ref import OutputRef
-from httomo.utils import catch_gputime, catchtime, gpu_enabled, log_rank, xp
+from httomo.utils import (
+    catch_gputime,
+    catchtime,
+    gpu_enabled,
+    log_rank,
+    xp,
+    search_max_slices_iterative,
+)
 
 
 import numpy as np
@@ -447,6 +454,17 @@ def calculate_max_slices(
                 * np.prod(non_slice_dims_shape)
                 * data_dtype.itemsize
             )
+        elif self.memory_gpu.method == "iterative":
+            # The iterative method may use the GPU
+            assert gpu_enabled, "GPU method used on a system without GPU support"
+            with xp.cuda.Device(self._gpu_id):
+                gpumem_cleanup()
+            return (
+                self._calculate_max_slices_iterative(
+                    data_dtype, non_slice_dims_shape, available_memory
+                ),
+                available_memory,
+            )
         else:
             (
                 memory_bytes_method,
@@ -462,6 +480,31 @@ def calculate_max_slices(
             available_memory - subtract_bytes
         ) // memory_bytes_method, available_memory
 
+    def _calculate_max_slices_iterative(
+        self,
+        data_dtype: np.dtype,
+        non_slice_dims_shape: Tuple[int, int],
+        available_memory: int,
+    ) -> int:
+        def get_mem_bytes(current_slices):
+            try:
+                memory_bytes = self._query.calculate_memory_bytes_for_slices(
+                    dims_shape=(
+                        current_slices,
+                        non_slice_dims_shape[0],
+                        non_slice_dims_shape[1],
+                    ),
+                    dtype=data_dtype,
+                    **self._unwrap_output_ref_values(),
+                )
+                return memory_bytes
+            except:
+                return 2**64
+            finally:
+                gpumem_cleanup()
+
+        return search_max_slices_iterative(available_memory, get_mem_bytes)
+
     def _unwrap_output_ref_values(self) -> Dict[str, Any]:
         """
         Iterate through params in `self.config_params` and, for any value of type `OutputRef`,
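The bare `except` returning `2**64` in the `get_mem_bytes` closure above makes a failed per-slice estimate behave as "does not fit", so the iterative search simply backs off instead of aborting. A small self-contained sketch of that guard pattern follows; the `flaky_estimator` toy model and the 1 GiB budget are illustrative assumptions, not httomo code.

from httomo.utils import search_max_slices_iterative


def flaky_estimator(slices: int) -> int:
    # Toy stand-in for a per-slice memory estimator that raises when asked
    # about blocks it cannot handle (illustrative assumption).
    if slices > 64:
        raise RuntimeError("estimator failed")
    return slices * 50 * 1024**2  # ~50 MB per slice


def get_mem_bytes(slices: int) -> int:
    try:
        return flaky_estimator(slices)
    except Exception:
        # Treat a failed estimate as "does not fit" so the search backs off.
        return 2**64


# With a 1 GiB budget the search settles on a slice count whose estimate
# stays within the budget despite the estimator failing for large requests.
print(search_max_slices_iterative(1024**3, get_mem_bytes))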

httomo/runner/methods_repository_interface.py

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ def calculate_memory_bytes(
         """Calculate the memory required in bytes, returning bytes method and subtract bytes tuple"""
         ... # pragma: no cover
 
+    def calculate_memory_bytes_for_slices(
+        self, dims_shape: Tuple[int, int, int], dtype: np.dtype, **kwargs
+    ) -> int:
+        """Calculate the memory required in bytes for a given 3D grid"""
+        ... # pragma: no cover
+
     def calculate_output_dims(
         self, non_slice_dims_shape: Tuple[int, int], **kwargs
     ) -> Tuple[int, int]:
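For context, a minimal sketch of what an implementation of the new `calculate_memory_bytes_for_slices` hook could look like for a hypothetical element-wise GPU method that keeps one input and one output array resident at the same time; the function name and the factor of two are illustrative assumptions, not part of this commit.

from typing import Tuple

import numpy as np


def _calc_memory_bytes_for_slices_elementwise(
    dims_shape: Tuple[int, int, int], dtype: np.dtype, **kwargs
) -> int:
    """Estimate device bytes needed for a (slices, dim1, dim2) block."""
    slices, dim1, dim2 = dims_shape
    # Input plus output array of the same shape held on the GPU at once.
    return 2 * slices * dim1 * dim2 * np.dtype(dtype).itemsize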

httomo/sweep_runner/param_sweep_runner.py

Lines changed: 14 additions & 7 deletions
@@ -16,12 +16,12 @@
 from httomo.sweep_runner.param_sweep_block import ParamSweepBlock
 from httomo.sweep_runner.side_output_manager import SideOutputManager
 from httomo.sweep_runner.stages import NonSweepStage, Stages, SweepStage
-from httomo.utils import catchtime, log_exception, log_once
-from httomo.runner.gpu_utils import get_available_gpu_memory
+from httomo.utils import catchtime, log_exception, log_once, search_max_slices_iterative
+from httomo.runner.gpu_utils import get_available_gpu_memory, gpumem_cleanup
 from httomo.preview import PreviewConfig, PreviewDimConfig
 from httomo.runner.dataset_store_interfaces import DataSetSource
 from httomo_backends.methods_database.packages.backends.httomolibgpu.supporting_funcs.prep.phase import (
-    _calc_memory_bytes_paganin_filter,
+    _calc_memory_bytes_for_slices_paganin_filter,
 )
 
 
@@ -322,8 +322,15 @@ def _slices_to_fit_memory_Paganin(source: DataSetSource) -> int:
     angles_total = source.aux_data.angles_length
     det_X_length = source.chunk_shape[2]
 
-    (memory_bytes_method, subtract_bytes) = _calc_memory_bytes_paganin_filter(
-        (angles_total, det_X_length), dtype=np.float32()
-    )
+    def get_mem_bytes(slices):
+        try:
+            return _calc_memory_bytes_for_slices_paganin_filter(
+                (slices, angles_total, det_X_length), dtype=np.float32()
+            )
+        except:
+            return 2**64
+        finally:
+            gpumem_cleanup()
 
-    return (available_memory - subtract_bytes) // memory_bytes_method
+    gpumem_cleanup()
+    return search_max_slices_iterative(available_memory, get_mem_bytes)

httomo/utils.py

Lines changed: 57 additions & 0 deletions
@@ -299,3 +299,60 @@ def mpi_abort_excepthook(type, value, traceback):
     log_rank("\n".join(format_tb(traceback)), MPI.COMM_WORLD)
     MPI.COMM_WORLD.Abort()
     sys.__excepthook__(type, value, traceback)
+
+
+def search_max_slices_iterative(
+    available_memory: int, get_mem_bytes: Callable[[int], int]
+) -> int:
+    """
+    Approximates the maximum number of slices that fit into GPU memory for a given function.
+    The memory profile of the function must be increasing as a function of the number of slices.
+    First, a linear approximation of the memory profile is performed. If this is not accurate enough,
+    a binary search follows to determine the number of fitting slices. This function never returns a
+    number of slices for which `get_mem_bytes(slices) > available_memory`.
+
+    :param available_memory: Bytes of available device memory
+    :type available_memory: int
+    :param get_mem_bytes: A functor that produces the bytes of device memory needed for a given number of slices.
+    :type get_mem_bytes: Callable[[int], int]
+    :return: Returns the approximation of the maximum number of fitting slices.
+    :rtype: int
+    """
+    MEM_RATIO_THRESHOLD = 0.9  # 90% of the available device memory is the target
+
+    # Find a number of slices that does not fit
+    current_slices = 100
+    slices_high = None
+    memory_bytes = get_mem_bytes(current_slices)
+    if memory_bytes > available_memory:
+        # Found upper limit, continue to binary search
+        slices_high = current_slices
+    else:
+        # Linear approximation
+        current_slices = int(current_slices * available_memory / memory_bytes)
+        while True:
+            memory_bytes = get_mem_bytes(current_slices)
+            if memory_bytes > available_memory:
+                # Found upper limit, continue to binary search
+                break
+            elif memory_bytes >= available_memory * MEM_RATIO_THRESHOLD:
+                # This is "good enough", return
+                return current_slices
+
+            # If linear approximation is not enough, just double every iteration
+            current_slices *= 2
+        slices_high = current_slices
+
+    # Binary search between low and high
+    slices_low = 0
+    while slices_high - slices_low > 1:
+        current_slices = (slices_low + slices_high) // 2
+        memory_bytes = get_mem_bytes(current_slices)
+        if memory_bytes > available_memory:
+            slices_high = current_slices
+        elif memory_bytes >= available_memory * MEM_RATIO_THRESHOLD:
+            return current_slices
+        else:
+            slices_low = current_slices
+
+    return slices_low
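A short usage sketch of the new helper; the 2 GiB budget, the fixed 100 MB overhead, and the 40 MB-per-slice memory model below are illustrative assumptions, not values from this commit.

from httomo.utils import search_max_slices_iterative

AVAILABLE_BYTES = 2 * 1024**3  # assumed 2 GiB of free device memory


def get_mem_bytes(slices: int) -> int:
    # Toy memory model: a fixed 100 MB overhead plus ~40 MB per slice.
    return 100 * 1024**2 + slices * 40 * 1024**2


max_slices = search_max_slices_iterative(AVAILABLE_BYTES, get_mem_bytes)
# The returned slice count never exceeds the budget, per the docstring above.
assert get_mem_bytes(max_slices) <= AVAILABLE_BYTES
print(max_slices)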

tests/method_wrappers/test_generic.py

Lines changed: 100 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import numpy as np
 from httomo.method_wrappers import make_method_wrapper
@@ -689,6 +689,105 @@ def test_method(data):
     assert available_memory == 1_000_000_000
 
 
+def _linear_mem(*args, **kwargs):
+    proj, x, y = kwargs["dims_shape"]
+    dtype = kwargs["dtype"]
+    return proj * x * y * dtype.itemsize
+
+
+def _linear_offset_mem(*args, **kwargs):
+    proj, x, y = kwargs["dims_shape"]
+    dtype = kwargs["dtype"]
+    return (x * y + proj * x * y + proj * x**2) * dtype.itemsize
+
+
+def _quadratic_mem(*args, **kwargs):
+    proj, x, y = kwargs["dims_shape"]
+    dtype = kwargs["dtype"]
+    return (4 * x * y + proj * proj * x * y) * dtype.itemsize
+
+
+THROW_OVER_SLICES = 77
+
+
+def _quadratic_mem_throws(*args, **kwargs):
+    proj, x, y = kwargs["dims_shape"]
+    dtype = kwargs["dtype"]
+    if proj > THROW_OVER_SLICES:
+        raise Exception("Memory estimator failed")
+    return (4 * x * y + proj * proj * x * y) * dtype.itemsize
+
+
+@pytest.mark.cupy
+@pytest.mark.parametrize("available_memory", [0, 1_000, 1_000_000, 1_000_000_000])
+@pytest.mark.parametrize(
+    "memcalc_fn",
+    [_linear_mem, _linear_offset_mem, _quadratic_mem, _quadratic_mem_throws],
+)
+def test_generic_calculate_max_slices_iterative(
+    mocker: MockerFixture,
+    dummy_block: DataSetBlock,
+    available_memory: int,
+    memcalc_fn: Callable,
+):
+    class FakeModule:
+        def test_method(data):
+            return data
+
+    mocker.patch(
+        "httomo.method_wrappers.generic.import_module", return_value=FakeModule
+    )
+
+    memory_gpu = GpuMemoryRequirement(multiplier=None, method="iterative")
+    repo = make_mock_repo(
+        mocker,
+        pattern=Pattern.projection,
+        output_dims_change=True,
+        implementation="gpu_cupy",
+        memory_gpu=memory_gpu,
+    )
+
+    memcalc_mock = mocker.patch.object(
+        repo.query("", ""), "calculate_memory_bytes_for_slices", side_effect=memcalc_fn
+    )
+    wrp = make_method_wrapper(
+        repo,
+        "mocked_module_path",
+        "test_method",
+        MPI.COMM_WORLD,
+        make_mock_preview_config(mocker),
+    )
+    shape_t = list(dummy_block.chunk_shape)
+    shape_t.pop(0)
+    shape = (shape_t[0], shape_t[1])
+    max_slices, _ = wrp.calculate_max_slices(
+        dummy_block.data.dtype,
+        shape,
+        available_memory,
+    )
+
+    check_slices = lambda slices: memcalc_mock(
+        dims_shape=(slices, shape[0], shape[1]), dtype=dummy_block.data.dtype
+    )
+    threshold = 0.9
+    if check_slices(1) > available_memory:
+        # If zero slice fits
+        assert max_slices == 0
+    else:
+        # Computed slices must fit in the available memory
+        assert check_slices(max_slices) <= available_memory
+
+        if memcalc_fn == _quadratic_mem_throws and max_slices + 1 >= THROW_OVER_SLICES:
+            with pytest.raises(Exception):
+                check_slices(max_slices + 1)
+        else:
+            # And one more slice must not fit OR above threshold
+            assert (
+                check_slices(max_slices + 1) > available_memory
+                or check_slices(max_slices) >= available_memory * threshold
+            )
+
+
 @pytest.mark.cupy
 def test_generic_calculate_output_dims(mocker: MockerFixture):
     class FakeModule:
class FakeModule:
