66from torch .utils .data import DataLoader
77from tqdm import tqdm
88
9- from cellseg_models_pytorch .inference import Inferer
9+ from cellseg_models_pytorch .models . base . _base_model_inst import BaseModelInst
1010from cellseg_models_pytorch .torch_datasets import WSIDatasetInfer
1111from cellseg_models_pytorch .wsi import SlideReader
1212from cellseg_models_pytorch .wsi .inst_merger import InstMerger
1717class WsiSegmenter :
1818 def __init__ (
1919 self ,
20- inferer : Inferer ,
2120 reader : SlideReader ,
21+ model : BaseModelInst ,
2222 level : int ,
2323 coordinates : List [Tuple [int , int , int , int ]],
2424 batch_size : int = 8 ,
@@ -27,9 +27,6 @@ def __init__(
2727 """Class for segmenting WSIs.
2828
2929 Parameters:
30- inferer (Inferer):
31- The initialized Inferer object for segmenting the WSIs. Can be either
32- `Inferer` or `SlidingWindowInferer`.
3330 reader (SlideReader):
3431 The `SlideReader` object for reading the WSIs.
3532 level (int):
@@ -43,7 +40,7 @@ def __init__(
4340 """
4441 self .batch_size = batch_size
4542 self .coordinates = coordinates
46- self .inferer = inferer
43+ self .model = model
4744
4845 self .dataset = WSIDatasetInfer (
4946 reader , coordinates , level = level , transform = normalization
@@ -87,10 +84,10 @@ def segment(self, save_dir: str, maptype: str = "amap") -> None:
8784 coords = [tuple (map (int , coord )) for coord in coords ]
8885
8986 # predict
90- probs = self .inferer .predict (im )
87+ probs = self .model .predict (im )
9188
9289 # post-process
93- self .inferer .post_process (
90+ self .model .post_process (
9491 probs ,
9592 dst = save_paths ,
9693 coords = coords ,
0 commit comments