Skip to content

Commit c3893fd

Browse files
committed
Merge branch 'bugfix'
Avoid custom handling of PydanticCoordinate/Roi tuples. These are now directly cast to the appropriate types. Handle writing to datasets that have lazy ops associated with them. We now write directly to the source array.
2 parents 723fcfa + 8dfaccc commit c3893fd

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/volara_torch/blockwise/predict.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def process(self, batch, request):
6565

6666
class Predict(BlockwiseTask):
6767
task_type: Literal["predict"] = "predict"
68-
roi: tuple[PydanticCoordinate, PydanticCoordinate] | None = None
6968
checkpoint: Annotated[
7069
TorchModel,
7170
Field(discriminator="model_type"),
@@ -75,6 +74,7 @@ class Predict(BlockwiseTask):
7574

7675
fit: Literal["overhang"] = "overhang"
7776
read_write_conflict: Literal[False] = False
77+
context_override: tuple[PydanticCoordinate, PydanticCoordinate] | None = None
7878
_out_array_dtype: np.dtype = np.dtype(np.uint8)
7979

8080
@property
@@ -85,7 +85,7 @@ def checkpoint_config(self) -> Model:
8585
def write_roi(self) -> Roi:
8686
in_data_roi = self.in_data.array("r").roi
8787
if self.roi is not None:
88-
return in_data_roi.intersect(Roi(self.roi[0], self.roi[1]))
88+
return in_data_roi.intersect(self.roi)
8989
else:
9090
return in_data_roi
9191

@@ -98,8 +98,16 @@ def write_size(self) -> Coordinate:
9898
return self.checkpoint_config.eval_output_shape * self.voxel_size
9999

100100
@property
101-
def context_size(self) -> Coordinate:
102-
return self.checkpoint_config.context * self.voxel_size
101+
def context_size(self) -> Coordinate | tuple[Coordinate, Coordinate]:
102+
context = self.checkpoint_config.context
103+
if isinstance(context, Coordinate):
104+
return self.checkpoint_config.context * self.voxel_size
105+
elif isinstance(context[0], Coordinate) and isinstance(context[1], Coordinate):
106+
return (context[0] * self.voxel_size, context[1] * self.voxel_size)
107+
else:
108+
raise NotImplementedError(
109+
f"Unsupported context {context} type: {type(context)}. Expected Coordinate or tuple of Coordinates."
110+
)
103111

104112
@property
105113
def task_name(self) -> str:
@@ -203,7 +211,7 @@ def process_block_func(self):
203211
for output_key, out_data in zip(output_keys, self.out_data):
204212
if out_data is not None:
205213
pipeline += ArrayWrite(
206-
output_key, out_data.array("a"), self.checkpoint.to_uint8
214+
output_key, Dataset.array(out_data, "a"), self.checkpoint.to_uint8
207215
)
208216

209217
print("Starting prediction...")

src/volara_torch/models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,32 @@ class Model(StrictBaseModel, ABC):
4444
The range of the output data. This is used to convert between `np.uint8` and
4545
`np.float32` data types for efficient reading/writing of model outputs.
4646
"""
47+
context_override: tuple[PydanticCoordinate, PydanticCoordinate] | None = None
48+
"""
49+
An optional override for asymetrical context sizes. This lets you define
50+
a specific lower and upper context size for the model which must equal
51+
the expected context size of input_shape - output_shape
52+
"""
4753

4854
@property
49-
def context(self) -> Coordinate:
55+
def context(self) -> Coordinate | tuple[Coordinate, Coordinate]:
5056
"""
5157
The context required to make tile artifact free predictions
5258
with this model.
5359
"""
54-
return (self.eval_input_shape - self.eval_output_shape) // 2
60+
expected_context = self.min_input_shape - self.min_output_shape
61+
if self.context_override is not None:
62+
assert (
63+
self.context_override[0] + self.context_override[1]
64+
) == expected_context, (
65+
f"Expected context override {self.context_override} to sum to {expected_context}, "
66+
f"but got {self.context_override[0] + self.context_override[1]}"
67+
)
68+
69+
if self.context_override is not None:
70+
return self.context_override
71+
else:
72+
return expected_context // 2
5573

5674
@abstractmethod
5775
def model(self) -> torch.nn.Module:

0 commit comments

Comments
 (0)