@@ -393,7 +393,9 @@ def validate(self, model, dataloader, max_batches):
393393 output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
394394
395395 elif self .single_gpu :
396- output = model (data_batch .cuda (self .data_parallel_device_ids [0 ]), batch_i )
396+ gpu_id = self .data_parallel_device_ids [0 ]
397+ data_batch = [x .cuda (gpu_id ) for x in data_batch if isinstance (x , torch .Tensor )]
398+ output = model (data_batch , batch_i )
397399
398400 else :
399401 output = model .validation_step (data_batch , batch_i )
@@ -474,7 +476,7 @@ def fit(self, model):
474476 self .__dp_train (model )
475477
476478 elif self .single_gpu :
477- self .__single_gpu_train (model )\
479+ self .__single_gpu_train (model )
478480
479481 # ON CPU
480482 else :
@@ -846,7 +848,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
846848 output = self .model (data_batch , batch_nb )
847849 output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
848850 elif self .single_gpu :
851+ gpu_id = self .data_parallel_device_ids [0 ]
852+ data_batch = [x .cuda (gpu_id ) for x in data_batch if isinstance (x , torch .Tensor )]
849853 output = self .model (data_batch .cuda (self .data_parallel_device_ids [0 ]), batch_nb )
854+
850855 else :
851856 output = self .model .training_step (data_batch , batch_nb )
852857
0 commit comments