diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index eb23c1c66..12ef21cb0 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -118,24 +118,6 @@ def _compute_iou(self, pred, true, eps=1e-7):
         iou = overlap / (union + eps)
         return iou
 
-    def preprocess_one_hot_masks(self, y_one_hot):
-        """Converts the masks to match the "low_res_masks" shape.
-        """
-        # Convert the labels to "low_res_mask" shape
-        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
-        target_length = self.model.transform.target_length
-        target_shape = self.model.transform.get_preprocess_shape(y_one_hot.shape[2], y_one_hot.shape[3], target_length)
-        y_one_hot = F.interpolate(input=y_one_hot, size=target_shape)
-        # Next, we pad the remaining region to (1024, 1024)
-        h, w = y_one_hot.shape[-2:]
-        padh = self.model.sam.image_encoder.img_size - h
-        padw = self.model.sam.image_encoder.img_size - w
-        y_one_hot = F.pad(input=y_one_hot, pad=(0, padw, 0, padh))
-        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
-        y_one_hot = F.interpolate(input=y_one_hot, size=(256, 256))
-
-        return y_one_hot
-
     def _compute_loss(self, batched_outputs, y_one_hot):
         """Compute the loss for one iteration. The loss is made up of two components:
         - The mask loss: dice score between the predicted masks and targets.
@@ -145,9 +127,6 @@ def _compute_loss(self, batched_outputs, y_one_hot):
 
         # Loop over the batch.
         for batch_output, targets in zip(batched_outputs, y_one_hot):
-            # Let's convert the inputs to the match the expected "low_res_masks" shape.
-            targets = self.preprocess_one_hot_masks(targets)
-
             predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
             # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
             # We swap the axes that go into the dice loss so that the object axis
@@ -182,8 +161,7 @@ def _compute_loss(self, batched_outputs, y_one_hot):
     #
 
     def _get_best_masks(self, batched_outputs, batched_iou_predictions):
-        # Batched mask and logit (low-res mask) predictions.
-        masks = torch.stack([m["masks"] for m in batched_outputs])
+        # Batched logit (low-res mask) predictions.
         logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
 
         # Determine the best IOU across the multi-object prediction axis
@@ -197,17 +175,13 @@ def _get_best_masks(self, batched_outputs, batched_iou_predictions):
         # Note that we squash the first two axes (batch x objects) into one when indexing.
         # That's why we need to reshape bax into (batch x objects) using a view.
         # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
-        batch_size, n_objects = masks.shape[:2]
-        h, w = masks.shape[-2:]
-        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
-
+        batch_size, n_objects = logits.shape[:2]
         h, w = logits.shape[-2:]
         logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
 
-        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
-        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
-        masks = (masks > 0.0).float()
-        return masks, logits
+        # Binarize the logits. We threshold at 0.0 since these are logits, so we don't need to apply a sigmoid.
+        logits = (logits > 0.0).float()
+        return logits
 
     def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
         """Compute the loss for several (sub-)iterations of iterative prompting.
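Note: with this change `_get_best_masks` works only with the low-res logits and returns their binarized version, instead of additionally stacking and indexing the upsampled "masks". A minimal standalone sketch of the selection logic, assuming the shapes used in the trainer; the function name and the one-hot construction of `best_iou_idx` are illustrative, since that part lies outside the hunk:

    import torch

    def select_best_low_res_masks(low_res_masks, iou_predictions):
        """Pick, per object, the multimask candidate with the highest predicted IOU.

        low_res_masks:   (batch, n_objects, n_candidates, 256, 256) logits
        iou_predictions: (batch, n_objects, n_candidates) predicted IOUs
        """
        # Best candidate per (batch, object), turned into a boolean mask over the candidate axis.
        best_iou_idx = torch.argmax(iou_predictions, dim=2)
        best_iou_idx = torch.nn.functional.one_hot(best_iou_idx, low_res_masks.shape[2]).bool()
        batch_size, n_objects = low_res_masks.shape[:2]
        h, w = low_res_masks.shape[-2:]
        # Boolean indexing squashes (batch x objects x candidates) into one axis;
        # the view restores the layout with a singleton candidate axis.
        selected = low_res_masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
        # Threshold at 0.0 because these are logits (sigmoid(0.0) == 0.5).
        return (selected > 0.0).float()

    masks = select_best_low_res_masks(torch.randn(2, 4, 3, 256, 256), torch.rand(2, 4, 3))
    assert masks.shape == (2, 4, 1, 256, 256)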
@@ -239,10 +213,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
 
             # Determine the next prompts based on current predictions.
             with torch.no_grad():
-                # Get the mask and logit predictions corresponding to the predicted object
+                # Get the logit predictions corresponding to the predicted object
                 # (per actual object) with the best IOU.
-                masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
-                batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
+                logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
+                batched_inputs = self._update_prompts(batched_inputs, y_one_hot, logits)
 
         loss = loss / num_subiter
         mask_loss = mask_loss / num_subiter
@@ -251,17 +225,17 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
 
         return loss, mask_loss, iou_regression_loss, mean_model_iou
 
-    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
+    def _update_prompts(self, batched_inputs, y_one_hot, logits_masks):
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
-        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
+        for y_, _inp, logits in zip(y_one_hot, batched_inputs, logits_masks):
             # here, we get each object in the pairs and do the point choices per-object
-            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
+            net_coords, net_labels, _, _ = self.prompt_generator(segmentation=y_, prediction=logits)
 
             # convert the point coordinates to the expected resolution for iterative prompting
             # NOTE:
             #   - "only" need to transform the point prompts from the iterative prompting
             #   - the `logits` are the low res masks (256, 256), hence do not need the transform
-            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
+            net_coords = self.model.transform.apply_coords_torch(net_coords, _inp["original_size"])
 
             updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                 if "point_coords" in _inp.keys() else net_coords
@@ -285,6 +259,24 @@ def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
     # Training Loop
     #
 
+    def _preprocess_labels(self, y):
+        """Converts the masks to match the shape of "low_res_masks".
+        """
+        # Convert the labels to "low_res_masks" shape
+        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
+        target_length = self.model.transform.target_length
+        target_shape = self.model.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
+        y = F.interpolate(input=y, size=target_shape)
+        # Next, we pad the remaining region to (1024, 1024)
+        h, w = y.shape[-2:]
+        padh = self.model.sam.image_encoder.img_size - h
+        padw = self.model.sam.image_encoder.img_size - w
+        y = F.pad(input=y, pad=(0, padw, 0, padh))
+        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
+        y = F.interpolate(input=y, size=(256, 256))
+
+        return y
+
     def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         """Compute one hot target (one mask per channel) for the sampled ids
         and restrict the number of sampled objects to the minimal number in the batch.
@@ -297,6 +289,7 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         # number of objects across the batch.
         n_objects = min(len(ids) for ids in sampled_ids)
 
+        y = self._preprocess_labels(y)
         y = y.to(self.device)
         # Compute the one hot targets for the seg-id.
         y_one_hot = torch.stack([
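Note: the downsampling of the label masks now happens once per batch in `_preprocess_labels` (called from `_preprocess_batch`), instead of once per sub-iteration inside the loss via the removed `preprocess_one_hot_masks`. The transform itself is unchanged: resize the longest side, pad to the square encoder input, then downsample to the low-res mask shape. A self-contained sketch of the same pipeline, with illustrative name and defaults (the trainer reads `target_length` and `img_size` from the model instead):

    import torch
    import torch.nn.functional as F

    def labels_to_low_res(y, long_side=1024, low_res=256):
        """Bring label masks of shape (batch, n_objects, H, W) into the
        (batch, n_objects, 256, 256) layout of SAM's "low_res_masks"."""
        # 1. Resize so the longest side matches the encoder input size,
        #    keeping the aspect ratio (the `ResizeLongestSide` rule).
        h, w = y.shape[-2:]
        scale = long_side / max(h, w)
        new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
        y = F.interpolate(input=y, size=(new_h, new_w))
        # 2. Zero-pad bottom and right up to the square (1024, 1024) canvas.
        y = F.pad(input=y, pad=(0, long_side - new_w, 0, long_side - new_h))
        # 3. Downsample the padded canvas to the low-res mask shape.
        return F.interpolate(input=y, size=(low_res, low_res))

    low_res = labels_to_low_res(torch.rand(1, 3, 512, 768).round())
    assert low_res.shape == (1, 3, 256, 256)

Doing this once in `_preprocess_batch` avoids repeating the two interpolations and the padding for every prompting sub-iteration.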
diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py
index 214092372..2ddba2aa5 100644
--- a/micro_sam/training/trainable_sam.py
+++ b/micro_sam/training/trainable_sam.py
@@ -21,7 +21,7 @@ def __init__(
         self,
         sam: Sam,
         device: Union[str, torch.device],
-        upsampled_masks: bool = True,
+        upsampled_masks: bool = False,
     ) -> None:
         super().__init__()
         self.sam = sam
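Note: since the targets are now compared against the low-res logits, `upsampled_masks` defaults to `False`, which presumably lets the forward pass skip post-processing the low-res masks to full resolution; callers that still need full-resolution masks have to opt in explicitly. A hypothetical usage sketch, assuming segment_anything's `sam_model_registry` (the checkpoint path is illustrative):

    import torch
    from segment_anything import sam_model_registry
    from micro_sam.training.trainable_sam import TrainableSAM

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b.pth")

    # New default: only the low-res mask predictions are kept.
    trainable_sam = TrainableSAM(sam=sam, device=device)

    # Previous behavior: opt in to the upsampled, full-resolution masks.
    trainable_sam_upsampled = TrainableSAM(sam=sam, device=device, upsampled_masks=True)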