@@ -231,15 +231,22 @@ def val_dataloader(self):
231
231
232
232
def on_after_batch_transfer (self , batch , dataloader_idx : int ):
233
233
"""Apply GPU transforms from constituent data modules to micro-batches."""
234
+ if not isinstance (batch , list ):
235
+ return batch
236
+
234
237
processed_micro_batches = []
235
238
for micro_batch in batch :
236
- dataset_idx = micro_batch .pop ("_dataset_idx" )
237
- dm = self .data_modules [dataset_idx ]
238
- if hasattr (dm , "on_after_batch_transfer" ):
239
- processed_micro_batch = dm .on_after_batch_transfer (
240
- micro_batch , dataloader_idx
241
- )
239
+ if isinstance (micro_batch , dict ) and "_dataset_idx" in micro_batch :
240
+ dataset_idx = micro_batch .pop ("_dataset_idx" )
241
+ dm = self .data_modules [dataset_idx ]
242
+ if hasattr (dm , "on_after_batch_transfer" ):
243
+ processed_micro_batch = dm .on_after_batch_transfer (
244
+ micro_batch , dataloader_idx
245
+ )
246
+ else :
247
+ processed_micro_batch = micro_batch
242
248
else :
249
+ # Handle case where micro_batch doesn't have _dataset_idx (e.g., from model summary)
243
250
processed_micro_batch = micro_batch
244
251
processed_micro_batches .append (processed_micro_batch )
245
252
combined_batch = {}
0 commit comments