@@ -600,23 +600,29 @@ def __getitem__(self, idx):
600600 if isinstance (target , torch .Tensor ):
601601 target = target .to (self .estimator .device )
602602 else :
603- target ["boxes" ] = target ["boxes" ][0 ].to (self .estimator .device )
604- target ["labels" ] = target ["labels" ][0 ].to (self .estimator .device )
605- target ["scores" ] = target ["scores" ][0 ].to (self .estimator .device )
606- target = [target ]
607- _ = self ._train_step (images = images , target = target , mask = None )
603+ targets = []
604+ for idx in range (target ['boxes' ].shape [0 ]):
605+ targets .append ({
606+ 'boxes' : target ['boxes' ][idx ].to (self .estimator .device ),
607+ 'labels' : target ['labels' ][idx ].to (self .estimator .device ),
608+ 'scores' : target ['scores' ][idx ].to (self .estimator .device ),
609+ })
610+ _ = self ._train_step (images = images , target = targets , mask = None )
608611 else :
609612 for images , target , mask_i in data_loader :
610613 images = images .to (self .estimator .device )
611614 if isinstance (target , torch .Tensor ):
612615 target = target .to (self .estimator .device )
613616 else :
614- target ["boxes" ] = target ["boxes" ][0 ].to (self .estimator .device )
615- target ["labels" ] = target ["labels" ][0 ].to (self .estimator .device )
616- target ["scores" ] = target ["scores" ][0 ].to (self .estimator .device )
617- target = [target ]
617+ targets = []
618+ for idx in range (target ['boxes' ].shape [0 ]):
619+ targets .append ({
620+ 'boxes' : target ['boxes' ][idx ].to (self .estimator .device ),
621+ 'labels' : target ['labels' ][idx ].to (self .estimator .device ),
622+ 'scores' : target ['scores' ][idx ].to (self .estimator .device ),
623+ })
618624 mask_i = mask_i .to (self .estimator .device )
619- _ = self ._train_step (images = images , target = target , mask = mask_i )
625+ _ = self ._train_step (images = images , target = targets , mask = mask_i )
620626
621627 # Write summary
622628 if self .summary_writer is not None : # pragma: no cover
0 commit comments