|
1 | | -from typing import List, Optional, Union |
| 1 | +from typing import Callable, List, Optional, Union |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | from httomo.method_wrappers import make_method_wrapper |
@@ -689,6 +689,105 @@ def test_method(data): |
689 | 689 | assert available_memory == 1_000_000_000 |
690 | 690 |
|
691 | 691 |
|
| 692 | +def _linear_mem(*args, **kwargs): |
| 693 | + proj, x, y = kwargs["dims_shape"] |
| 694 | + dtype = kwargs["dtype"] |
| 695 | + return proj * x * y * dtype.itemsize |
| 696 | + |
| 697 | + |
| 698 | +def _linear_offset_mem(*args, **kwargs): |
| 699 | + proj, x, y = kwargs["dims_shape"] |
| 700 | + dtype = kwargs["dtype"] |
| 701 | + return (x * y + proj * x * y + proj * x**2) * dtype.itemsize |
| 702 | + |
| 703 | + |
| 704 | +def _quadratic_mem(*args, **kwargs): |
| 705 | + proj, x, y = kwargs["dims_shape"] |
| 706 | + dtype = kwargs["dtype"] |
| 707 | + return (4 * x * y + proj * proj * x * y) * dtype.itemsize |
| 708 | + |
| 709 | + |
| 710 | +THROW_OVER_SLICES = 77 |
| 711 | + |
| 712 | + |
| 713 | +def _quadratic_mem_throws(*args, **kwargs): |
| 714 | + proj, x, y = kwargs["dims_shape"] |
| 715 | + dtype = kwargs["dtype"] |
| 716 | + if proj > THROW_OVER_SLICES: |
| 717 | + raise Exception("Memory estimator failed") |
| 718 | + return (4 * x * y + proj * proj * x * y) * dtype.itemsize |
| 719 | + |
| 720 | + |
| 721 | +@pytest.mark.cupy |
| 722 | +@pytest.mark.parametrize("available_memory", [0, 1_000, 1_000_000, 1_000_000_000]) |
| 723 | +@pytest.mark.parametrize( |
| 724 | + "memcalc_fn", |
| 725 | + [_linear_mem, _linear_offset_mem, _quadratic_mem, _quadratic_mem_throws], |
| 726 | +) |
| 727 | +def test_generic_calculate_max_slices_iterative( |
| 728 | + mocker: MockerFixture, |
| 729 | + dummy_block: DataSetBlock, |
| 730 | + available_memory: int, |
| 731 | + memcalc_fn: Callable, |
| 732 | +): |
| 733 | + class FakeModule: |
| 734 | + def test_method(data): |
| 735 | + return data |
| 736 | + |
| 737 | + mocker.patch( |
| 738 | + "httomo.method_wrappers.generic.import_module", return_value=FakeModule |
| 739 | + ) |
| 740 | + |
| 741 | + memory_gpu = GpuMemoryRequirement(multiplier=None, method="iterative") |
| 742 | + repo = make_mock_repo( |
| 743 | + mocker, |
| 744 | + pattern=Pattern.projection, |
| 745 | + output_dims_change=True, |
| 746 | + implementation="gpu_cupy", |
| 747 | + memory_gpu=memory_gpu, |
| 748 | + ) |
| 749 | + |
| 750 | + memcalc_mock = mocker.patch.object( |
| 751 | + repo.query("", ""), "calculate_memory_bytes_for_slices", side_effect=memcalc_fn |
| 752 | + ) |
| 753 | + wrp = make_method_wrapper( |
| 754 | + repo, |
| 755 | + "mocked_module_path", |
| 756 | + "test_method", |
| 757 | + MPI.COMM_WORLD, |
| 758 | + make_mock_preview_config(mocker), |
| 759 | + ) |
| 760 | + shape_t = list(dummy_block.chunk_shape) |
| 761 | + shape_t.pop(0) |
| 762 | + shape = (shape_t[0], shape_t[1]) |
| 763 | + max_slices, _ = wrp.calculate_max_slices( |
| 764 | + dummy_block.data.dtype, |
| 765 | + shape, |
| 766 | + available_memory, |
| 767 | + ) |
| 768 | + |
| 769 | + check_slices = lambda slices: memcalc_mock( |
| 770 | + dims_shape=(slices, shape[0], shape[1]), dtype=dummy_block.data.dtype |
| 771 | + ) |
| 772 | + threshold = 0.9 |
| 773 | + if check_slices(1) > available_memory: |
| 774 | + # If zero slice fits |
| 775 | + assert max_slices == 0 |
| 776 | + else: |
| 777 | + # Computed slices must fit in the available memory |
| 778 | + assert check_slices(max_slices) <= available_memory |
| 779 | + |
| 780 | + if memcalc_fn == _quadratic_mem_throws and max_slices + 1 >= THROW_OVER_SLICES: |
| 781 | + with pytest.raises(Exception): |
| 782 | + check_slices(max_slices + 1) |
| 783 | + else: |
| 784 | + # And one more slice must not fit OR above threshold |
| 785 | + assert ( |
| 786 | + check_slices(max_slices + 1) > available_memory |
| 787 | + or check_slices(max_slices) >= available_memory * threshold |
| 788 | + ) |
| 789 | + |
| 790 | + |
692 | 791 | @pytest.mark.cupy |
693 | 792 | def test_generic_calculate_output_dims(mocker: MockerFixture): |
694 | 793 | class FakeModule: |
|
0 commit comments