Skip to content

Commit 8c6a2ff

Browse files
edyoshikunziw-liu
andauthored
Patch ModelSummary with BatchTransforms (#307)
* adding the fix to run model summary * remove redundant branch * Revert "remove redundant branch" This reverts commit 1a06cc1. * remove fixme comment --------- Co-authored-by: Ziwen Liu <[email protected]>
1 parent 03cbc6a commit 8c6a2ff

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

viscy/data/combined.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,15 +231,22 @@ def val_dataloader(self):
231231

232232
def on_after_batch_transfer(self, batch, dataloader_idx: int):
233233
"""Apply GPU transforms from constituent data modules to micro-batches."""
234+
if not isinstance(batch, list):
235+
return batch
236+
234237
processed_micro_batches = []
235238
for micro_batch in batch:
236-
dataset_idx = micro_batch.pop("_dataset_idx")
237-
dm = self.data_modules[dataset_idx]
238-
if hasattr(dm, "on_after_batch_transfer"):
239-
processed_micro_batch = dm.on_after_batch_transfer(
240-
micro_batch, dataloader_idx
241-
)
239+
if isinstance(micro_batch, dict) and "_dataset_idx" in micro_batch:
240+
dataset_idx = micro_batch.pop("_dataset_idx")
241+
dm = self.data_modules[dataset_idx]
242+
if hasattr(dm, "on_after_batch_transfer"):
243+
processed_micro_batch = dm.on_after_batch_transfer(
244+
micro_batch, dataloader_idx
245+
)
246+
else:
247+
processed_micro_batch = micro_batch
242248
else:
249+
# Handle case where micro_batch doesn't have _dataset_idx (e.g., from model summary)
243250
processed_micro_batch = micro_batch
244251
processed_micro_batches.append(processed_micro_batch)
245252
combined_batch = {}

0 commit comments

Comments
 (0)