|
1 | 1 | import multiprocessing as mp |
| 2 | +import warnings |
2 | 3 | from abc import ABC |
3 | 4 | from concurrent import futures |
4 | 5 | from copy import deepcopy |
@@ -619,20 +620,43 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n |
619 | 620 | def initialize( |
620 | 621 | self, |
621 | 622 | image: np.ndarray, |
622 | | - tile_shape: List[int], |
623 | | - halo: List[int], |
624 | 623 | image_embeddings=None, |
625 | 624 | i: Optional[int] = None, |
| 625 | + tile_shape: Optional[List[int]] = None, |
| 626 | + halo: Optional[List[int]] = None, |
626 | 627 | verbose: bool = False, |
627 | 628 | embedding_save_path: Optional[str] = None, |
628 | 629 | ): |
629 | 630 | """ |
630 | 631 | """ |
631 | 632 | original_size = image.shape[:2] |
632 | | - if image_embeddings is None: |
| 633 | + |
| 634 | + have_tiling_params = (tile_shape is not None) and (halo is not None) |
| 635 | + if image_embeddings is None and have_tiling_params: |
| 636 | + if embedding_save_path is None: |
| 637 | + raise ValueError( |
| 638 | + "You have passed neither pre-computed embeddings nor a path for saving embeddings." |
| 639 | + "Embeddings with tiling can only be computed if a save path is given." |
| 640 | + ) |
633 | 641 | image_embeddings = util.precompute_image_embeddings( |
634 | 642 | self.predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path |
635 | 643 | ) |
| 644 | + elif image_embeddings is None and not have_tiling_params: |
| 645 | + raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)") |
| 646 | + else: |
| 647 | + feats = image_embeddings["features"] |
| 648 | + tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"] |
| 649 | + if have_tiling_params and ( |
| 650 | + (list(tile_shape) != list(tile_shape_)) or |
| 651 | + (list(halo) != list(halo_)) |
| 652 | + ): |
| 653 | + warnings.warn( |
| 654 | + "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and" |
| 655 | + "the values of the tiling parameters from the embeddings disagree with the ones that were passed." |
| 656 | + "The tiling parameters you have passed wil be ignored." |
| 657 | + ) |
| 658 | + tile_shape = tile_shape_ |
| 659 | + halo = halo_ |
636 | 660 |
|
637 | 661 | tiling = blocking([0, 0], original_size, tile_shape) |
638 | 662 | n_tiles = tiling.numberOfBlocks |
|
0 commit comments