@@ -65,7 +65,6 @@ def process(self, batch, request):
6565
6666class Predict (BlockwiseTask ):
6767 task_type : Literal ["predict" ] = "predict"
68- roi : tuple [PydanticCoordinate , PydanticCoordinate ] | None = None
6968 checkpoint : Annotated [
7069 TorchModel ,
7170 Field (discriminator = "model_type" ),
@@ -75,6 +74,7 @@ class Predict(BlockwiseTask):
7574
7675 fit : Literal ["overhang" ] = "overhang"
7776 read_write_conflict : Literal [False ] = False
77+ context_override : tuple [PydanticCoordinate , PydanticCoordinate ] | None = None
7878 _out_array_dtype : np .dtype = np .dtype (np .uint8 )
7979
8080 @property
@@ -85,7 +85,7 @@ def checkpoint_config(self) -> Model:
8585 def write_roi (self ) -> Roi :
8686 in_data_roi = self .in_data .array ("r" ).roi
8787 if self .roi is not None :
88- return in_data_roi .intersect (Roi ( self .roi [ 0 ], self . roi [ 1 ]) )
88+ return in_data_roi .intersect (self .roi )
8989 else :
9090 return in_data_roi
9191
@@ -98,8 +98,16 @@ def write_size(self) -> Coordinate:
9898 return self .checkpoint_config .eval_output_shape * self .voxel_size
9999
100100 @property
101- def context_size (self ) -> Coordinate :
102- return self .checkpoint_config .context * self .voxel_size
101+ def context_size (self ) -> Coordinate | tuple [Coordinate , Coordinate ]:
102+ context = self .checkpoint_config .context
103+ if isinstance (context , Coordinate ):
104+ return self .checkpoint_config .context * self .voxel_size
105+ elif isinstance (context [0 ], Coordinate ) and isinstance (context [1 ], Coordinate ):
106+ return (context [0 ] * self .voxel_size , context [1 ] * self .voxel_size )
107+ else :
108+ raise NotImplementedError (
109+ f"Unsupported context { context } type: { type (context )} . Expected Coordinate or tuple of Coordinates."
110+ )
103111
104112 @property
105113 def task_name (self ) -> str :
@@ -203,7 +211,7 @@ def process_block_func(self):
203211 for output_key , out_data in zip (output_keys , self .out_data ):
204212 if out_data is not None :
205213 pipeline += ArrayWrite (
206- output_key , out_data .array ("a" ), self .checkpoint .to_uint8
214+ output_key , Dataset .array (out_data , "a" ), self .checkpoint .to_uint8
207215 )
208216
209217 print ("Starting prediction..." )
0 commit comments