@@ -573,9 +573,9 @@ def __getitem__(self, idx):
573573 img = torch .from_numpy (self .x [idx ])
574574
575575 target = {}
576- target ["boxes" ] = torch .from_numpy (y [idx ]["boxes" ])
577- target ["labels" ] = torch .from_numpy (y [idx ]["labels" ])
578- target ["scores" ] = torch .from_numpy (y [idx ]["scores" ])
576+ target ["boxes" ] = torch .from_numpy (self . y [idx ]["boxes" ])
577+ target ["labels" ] = torch .from_numpy (self . y [idx ]["labels" ])
578+ target ["scores" ] = torch .from_numpy (self . y [idx ]["scores" ])
579579 mask_i = torch .from_numpy (self .mask [idx ])
580580
581581 return img , target , mask_i
@@ -600,19 +600,21 @@ def __getitem__(self, idx):
600600 if isinstance (target , torch .Tensor ):
601601 target = target .to (self .estimator .device )
602602 else :
603- target ["boxes" ] = target ["boxes" ].to (self .estimator .device )
604- target ["labels" ] = target ["labels" ].to (self .estimator .device )
605- target ["scores" ] = target ["scores" ].to (self .estimator .device )
603+ target ["boxes" ] = target ["boxes" ][0 ].to (self .estimator .device )
604+ target ["labels" ] = target ["labels" ][0 ].to (self .estimator .device )
605+ target ["scores" ] = target ["scores" ][0 ].to (self .estimator .device )
606+ target = [target ]
606607 _ = self ._train_step (images = images , target = target , mask = None )
607608 else :
608609 for images , target , mask_i in data_loader :
609610 images = images .to (self .estimator .device )
610611 if isinstance (target , torch .Tensor ):
611612 target = target .to (self .estimator .device )
612613 else :
613- target ["boxes" ] = target ["boxes" ].to (self .estimator .device )
614- target ["labels" ] = target ["labels" ].to (self .estimator .device )
615- target ["scores" ] = target ["scores" ].to (self .estimator .device )
614+ target ["boxes" ] = target ["boxes" ][0 ].to (self .estimator .device )
615+ target ["labels" ] = target ["labels" ][0 ].to (self .estimator .device )
616+ target ["scores" ] = target ["scores" ][0 ].to (self .estimator .device )
617+ target = [target ]
616618 mask_i = mask_i .to (self .estimator .device )
617619 _ = self ._train_step (images = images , target = target , mask = mask_i )
618620
0 commit comments