@@ -55,8 +55,8 @@ def __init__(
5555 postprocessing : List [Processing ],
5656 model_adapter : ModelAdapter ,
5757 default_ns : Union [
58- v0_5 .ParameterizedSize . N ,
59- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize . N ],
58+ v0_5 .ParameterizedSize_N ,
59+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
6060 ] = 10 ,
6161 default_batch_size : int = 1 ,
6262 ) -> None :
@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
179179 self .model_description .id or self .model_description .name
180180 )
181181
182- def predict_sample_with_blocking (
182+ def predict_sample_with_fixed_blocking (
183183 self ,
184184 sample : Sample ,
185+ input_block_shape : Mapping [MemberId , Mapping [AxisId , int ]],
186+ * ,
185187 skip_preprocessing : bool = False ,
186188 skip_postprocessing : bool = False ,
187- ns : Optional [
188- Union [
189- v0_5 .ParameterizedSize .N ,
190- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize .N ],
191- ]
192- ] = None ,
193- batch_size : Optional [int ] = None ,
194189 ) -> Sample :
195- """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
196190 if not skip_preprocessing :
197191 self .apply_preprocessing (sample )
198192
199- if isinstance (self .model_description , v0_4 .ModelDescr ):
200- raise NotImplementedError (
201- "predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
202- )
203-
204- ns = ns or self ._default_ns
205- if isinstance (ns , int ):
206- ns = {
207- (ipt .id , a .id ): ns
208- for ipt in self .model_description .inputs
209- for a in ipt .axes
210- if isinstance (a .size , v0_5 .ParameterizedSize )
211- }
212- input_block_shape = self .model_description .get_tensor_sizes (
213- ns , batch_size or self ._default_batch_size
214- ).inputs
215-
216193 n_blocks , input_blocks = sample .split_into_blocks (
217194 input_block_shape ,
218195 halo = self ._default_input_halo ,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(
239216
240217 return predicted_sample
241218
219+ def predict_sample_with_blocking (
220+ self ,
221+ sample : Sample ,
222+ skip_preprocessing : bool = False ,
223+ skip_postprocessing : bool = False ,
224+ ns : Optional [
225+ Union [
226+ v0_5 .ParameterizedSize_N ,
227+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
228+ ]
229+ ] = None ,
230+ batch_size : Optional [int ] = None ,
231+ ) -> Sample :
232+ """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
233+
234+ if isinstance (self .model_description , v0_4 .ModelDescr ):
235+ raise NotImplementedError (
236+ "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
237+ + f" { self .model_description .name } ."
238+ + " Consider using `predict_sample_with_fixed_blocking`"
239+ )
240+
241+ ns = ns or self ._default_ns
242+ if isinstance (ns , int ):
243+ ns = {
244+ (ipt .id , a .id ): ns
245+ for ipt in self .model_description .inputs
246+ for a in ipt .axes
247+ if isinstance (a .size , v0_5 .ParameterizedSize )
248+ }
249+ input_block_shape = self .model_description .get_tensor_sizes (
250+ ns , batch_size or self ._default_batch_size
251+ ).inputs
252+
253+ return self .predict_sample_with_fixed_blocking (
254+ sample ,
255+ input_block_shape = input_block_shape ,
256+ skip_preprocessing = skip_preprocessing ,
257+ skip_postprocessing = skip_postprocessing ,
258+ )
259+
242260 # def predict(
243261 # self,
244262 # inputs: Predict_IO,
@@ -310,8 +328,8 @@ def create_prediction_pipeline(
310328 ),
311329 model_adapter : Optional [ModelAdapter ] = None ,
312330 ns : Union [
313- v0_5 .ParameterizedSize . N ,
314- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize . N ],
331+ v0_5 .ParameterizedSize_N ,
332+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
315333 ] = 10 ,
316334 ** deprecated_kwargs : Any ,
317335) -> PredictionPipeline :
0 commit comments