Skip to content

Commit 1accd31

Browse files
authored
Merge branch 'main' into radway_69_iterative_lprec_estimator
2 parents 331b96b + 1c29573 commit 1accd31

17 files changed

+357
-221
lines changed

httomo/darks_flats.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,38 @@ def get_separate(config: DarksFlatsFileConfig):
144144
preview_config.detector_x.start : preview_config.detector_x.stop,
145145
]
146146

147+
if (
148+
darks_config.ignore
149+
and not flats_config.ignore
150+
and darks_config.file != flats_config.file
151+
):
152+
flats = get_separate(flats_config)
153+
darks = np.zeros(
154+
(
155+
1,
156+
preview_config.detector_y.stop - preview_config.detector_y.start,
157+
preview_config.detector_x.stop - preview_config.detector_x.start,
158+
),
159+
dtype=flats.dtype,
160+
)
161+
return darks, flats
162+
163+
if (
164+
not darks_config.ignore
165+
and flats_config.ignore
166+
and darks_config.file != flats_config.file
167+
):
168+
darks = get_separate(darks_config)
169+
flats = np.ones(
170+
(
171+
1,
172+
preview_config.detector_y.stop - preview_config.detector_y.start,
173+
preview_config.detector_x.stop - preview_config.detector_x.start,
174+
),
175+
dtype=darks.dtype,
176+
)
177+
return darks, flats
178+
147179
if darks_config.file != flats_config.file:
148180
darks = get_separate(darks_config)
149181
flats = get_separate(flats_config)

httomo/data/dataset_store.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ def __init__(
304304

305305
source.finalize()
306306

307+
@property
308+
def padding(self) -> Tuple[int, int]:
309+
return self._padding
310+
307311
@property
308312
def aux_data(self) -> AuxiliaryData:
309313
return self._aux_data

httomo/loaders/standard_tomo_loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def __init__(
9191
self._log_info()
9292
weakref.finalize(self, self.finalize)
9393

94+
@property
95+
def padding(self) -> Tuple[int, int]:
96+
return self._padding
97+
9498
@property
9599
def dtype(self) -> np.dtype:
96100
return self._data.dtype

httomo/runner/block_split.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from httomo.runner.dataset import DataSetBlock
32
from httomo.runner.dataset_store_interfaces import DataSetSource
43

@@ -24,9 +23,31 @@ class BlockSplitter:
2423

2524
def __init__(self, source: DataSetSource, max_slices: int):
2625
self._source = source
27-
self._chunk_size = source.chunk_shape[source.slicing_dim]
26+
self._chunk_size = (
27+
source.chunk_shape[source.slicing_dim]
28+
+ source.padding[0]
29+
+ source.padding[1]
30+
)
2831
self._max_slices = int(min(max_slices, self._chunk_size))
29-
self._num_blocks = math.ceil(self._chunk_size / self._max_slices)
32+
33+
core_slices = self._max_slices - source.padding[0] - source.padding[1]
34+
step = core_slices
35+
num_blocks = 0
36+
start_of_next_block = 0
37+
while (
38+
start_of_next_block + source.padding[0] + core_slices + source.padding[1]
39+
< self._chunk_size
40+
):
41+
num_blocks += 1
42+
start_of_next_block += step
43+
44+
if (
45+
start_of_next_block + source.padding[0] + core_slices + source.padding[1]
46+
>= self._chunk_size
47+
):
48+
num_blocks += 1
49+
50+
self._num_blocks = num_blocks
3051
self._current = 0
3152
assert self._source.slicing_dim in [
3253
0,
@@ -35,24 +56,23 @@ def __init__(self, source: DataSetSource, max_slices: int):
3556

3657
@property
3758
def slices_per_block(self) -> int:
38-
return self._max_slices
59+
return self._max_slices - self._source.padding[0] - self._source.padding[1]
3960

4061
def __len__(self):
4162
return self._num_blocks
4263

4364
def __getitem__(self, idx: int) -> DataSetBlock:
4465
start = idx * self.slices_per_block
45-
if start >= self._chunk_size:
66+
if (
67+
start
68+
>= self._chunk_size - self._source.padding[0] - self._source.padding[1]
69+
):
4670
raise IndexError("Index out of bounds")
47-
len = min(self.slices_per_block, self._chunk_size - start)
71+
len = min(
72+
self.slices_per_block,
73+
self._chunk_size
74+
- self._source.padding[0]
75+
- self._source.padding[1]
76+
- start,
77+
)
4878
return self._source.read_block(start, len)
49-
50-
def __iter__(self):
51-
return self
52-
53-
def __next__(self) -> DataSetBlock:
54-
if self._current >= len(self):
55-
raise StopIteration
56-
v = self[self._current]
57-
self._current += 1
58-
return v

httomo/runner/dataset_store_interfaces.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,13 @@ def finalize(self):
9696
to give implementations a chance to close files, free memory, etc."""
9797
... # pragma: no cover
9898

99+
@property
100+
def padding(self) -> Tuple[int, int]:
101+
"""
102+
Padding present in `DataSetBlock`'s produced by the `DataSetSource`.
103+
"""
104+
... # pragma: no cover
105+
99106

100107
class DataSetSink(Protocol):
101108
@property

httomo/runner/task_runner.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,6 @@ def _execute_section(self, section: Section, section_index: int = 0):
135135
slicing_dim_section: Literal[0, 1] = _get_slicing_dim(section.pattern) - 1 # type: ignore
136136
self.determine_max_slices(section, slicing_dim_section, self.source.aux_data.get_angles())
137137

138-
# Account for potential padding in number of max slices
139-
padding = determine_section_padding(section)
140-
section.max_slices -= padding[0] + padding[1]
141-
142138
self._log_pipeline(
143139
f"Maximum amount of slices is {section.max_slices} for section {section_index}",
144140
level=logging.DEBUG,
@@ -260,9 +256,9 @@ def _execute_section_block(
260256
) -> DataSetBlock:
261257
if_previous_block_is_on_gpu = False
262258
convert_gpu_block_to_cpu = False
259+
if_current_block_is_on_gpu = False
263260

264261
for ind, method in enumerate(section):
265-
if_current_block_is_on_gpu = False
266262
if method.implementation == "gpu_cupy":
267263
if_current_block_is_on_gpu = True
268264
if method.method_name == "calculate_stats" and if_previous_block_is_on_gpu:
@@ -410,7 +406,9 @@ def determine_max_slices(self, section: Section, slicing_dim: int, angles: np.nd
410406
assert len(section) > 0, "Section should contain at least 1 method"
411407

412408
data_shape = self.source.chunk_shape
413-
max_slices = data_shape[slicing_dim]
409+
max_slices = (
410+
data_shape[slicing_dim] + self.source.padding[0] + self.source.padding[1]
411+
)
414412
# loop over all methods in section
415413
has_gpu = False
416414
for idx, m in enumerate(section):
@@ -463,6 +461,20 @@ def determine_max_slices(self, section: Section, slicing_dim: int, angles: np.nd
463461
max_slices_methods[idx] = min(max_slices, slices_estimated)
464462
non_slice_dims_shape = output_dims
465463

464+
if (
465+
min(max_slices_methods)
466+
< 1 + self.source.padding[0] + self.source.padding[1]
467+
):
468+
padded_method = next(
469+
method.method_name for method in section.methods if method.padding
470+
)
471+
err_str = (
472+
"Unable to process data due to GPU memory limitations.\n"
473+
f"Please remove method '{padded_method}' from the pipeline, or run on a "
474+
"machine with more GPU memory."
475+
)
476+
raise ValueError(err_str)
477+
466478
section.max_slices = min(max_slices_methods)
467479

468480
def _pass_min_block_length_to_intermediate_data_wrapper(self, section: Section):

httomo/ui_layer.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import h5py
1010
from mpi4py.MPI import Comm
1111

12-
from httomo.preview import PreviewConfig, PreviewDimConfig
12+
from httomo.preview import PreviewConfig
1313
from httomo.runner.method_wrapper import MethodWrapper
1414
from httomo.runner.pipeline import Pipeline
1515

@@ -85,7 +85,6 @@ def build_pipeline(self) -> Pipeline:
8585
self._append_methods_list(
8686
i, task_conf, methods_list, parameters, method_id_map
8787
)
88-
fix_preview_y_if_smaller_than_padding(loader, methods_list)
8988
return Pipeline(loader=loader, methods=methods_list)
9089

9190
def _append_methods_list(
@@ -169,27 +168,6 @@ def _setup_loader(self) -> LoaderInterface:
169168
return loader
170169

171170

172-
def fix_preview_y_if_smaller_than_padding(
173-
loader: LoaderInterface, methods_list: List[MethodWrapper]
174-
) -> None:
175-
vertical_preview_length = (
176-
loader.preview.detector_y.stop - loader.preview.detector_y.start
177-
) // loader.comm.size
178-
max_pad_value = 0
179-
for _, m in enumerate(methods_list):
180-
if m.padding:
181-
max_pad_value = max(sum(m.calculate_padding()), max_pad_value)
182-
if max_pad_value >= vertical_preview_length:
183-
loader.preview = PreviewConfig(
184-
angles=loader.preview.angles,
185-
detector_y=PreviewDimConfig(
186-
start=loader.preview.detector_y.start - max_pad_value // 2,
187-
stop=loader.preview.detector_y.stop + max_pad_value // 2,
188-
),
189-
detector_x=loader.preview.detector_x,
190-
)
191-
192-
193171
def get_valid_ref_str(parameters: Dict[str, Any]) -> Dict[str, str]:
194172
"""Find valid reference strings inside dictionary
195173

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def compare_tif(files_list_to_compare: list, file_path_to_references: list):
535535
Image.open(tif_files_references[index])
536536
)
537537
res_norm = np.linalg.norm(res_images.flatten())
538-
assert res_norm < 1e-6
538+
assert res_norm < 1e-3
539539

540540

541541
def change_value_parameters_method_pipeline(

0 commit comments

Comments
 (0)