69 changes: 31 additions & 38 deletions micro_sam/training/sam_trainer.py
@@ -118,24 +118,6 @@ def _compute_iou(self, pred, true, eps=1e-7):
         iou = overlap / (union + eps)
         return iou
 
-    def preprocess_one_hot_masks(self, y_one_hot):
-        """Converts the masks to match the "low_res_masks" shape.
-        """
-        # Convert the labels to "low_res_mask" shape
-        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
-        target_length = self.model.transform.target_length
-        target_shape = self.model.transform.get_preprocess_shape(y_one_hot.shape[2], y_one_hot.shape[3], target_length)
-        y_one_hot = F.interpolate(input=y_one_hot, size=target_shape)
-        # Next, we pad the remaining region to (1024, 1024)
-        h, w = y_one_hot.shape[-2:]
-        padh = self.model.sam.image_encoder.img_size - h
-        padw = self.model.sam.image_encoder.img_size - w
-        y_one_hot = F.pad(input=y_one_hot, pad=(0, padw, 0, padh))
-        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
-        y_one_hot = F.interpolate(input=y_one_hot, size=(256, 256))
-
-        return y_one_hot
-
     def _compute_loss(self, batched_outputs, y_one_hot):
         """Compute the loss for one iteration. The loss is made up of two components:
         - The mask loss: dice score between the predicted masks and targets.
@@ -145,9 +127,6 @@ def _compute_loss(self, batched_outputs, y_one_hot):
 
         # Loop over the batch.
         for batch_output, targets in zip(batched_outputs, y_one_hot):
-            # Let's convert the inputs to the match the expected "low_res_masks" shape.
-            targets = self.preprocess_one_hot_masks(targets)
-
             predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
             # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
             # We swap the axes that go into the dice loss so that the object axis
@@ -182,8 +161,7 @@ def _compute_loss(self, batched_outputs, y_one_hot):
     #
 
     def _get_best_masks(self, batched_outputs, batched_iou_predictions):
-        # Batched mask and logit (low-res mask) predictions.
-        masks = torch.stack([m["masks"] for m in batched_outputs])
+        # Batched logit (low-res mask) predictions.
         logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
 
         # Determine the best IOU across the multi-object prediction axis
@@ -197,17 +175,13 @@ def _get_best_masks(self, batched_outputs, batched_iou_predictions):
         # Note that we squash the first two axes (batch x objects) into one when indexing.
         # That's why we need to reshape bax into (batch x objects) using a view.
         # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
-        batch_size, n_objects = masks.shape[:2]
-        h, w = masks.shape[-2:]
-        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
 
         batch_size, n_objects = logits.shape[:2]
         h, w = logits.shape[-2:]
         logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
 
-        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
-        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
-        masks = (masks > 0.0).float()
-        return masks, logits
+        # Binarize the logit.
+        logits = (logits > 0.0).float()
+        return logits
 
     def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
         """Compute the loss for several (sub-)iterations of iterative prompting.
@@ -239,10 +213,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
 
             # Determine the next prompts based on current predictions.
             with torch.no_grad():
-                # Get the mask and logit predictions corresponding to the predicted object
+                # Get the logit predictions corresponding to the predicted object
                 # (per actual object) with the best IOU.
-                masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
-                batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
+                logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
+                batched_inputs = self._update_prompts(batched_inputs, y_one_hot, logits)
 
         loss = loss / num_subiter
         mask_loss = mask_loss / num_subiter
@@ -251,17 +225,17 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
 
         return loss, mask_loss, iou_regression_loss, mean_model_iou
 
-    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
+    def _update_prompts(self, batched_inputs, y_one_hot, logits_masks):
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
-        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
+        for y_, _inp, logits in zip(y_one_hot, batched_inputs, logits_masks):
             # here, we get each object in the pairs and do the point choices per-object
-            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
+            net_coords, net_labels, _, _ = self.prompt_generator(segmentation=y_, prediction=logits)
 
             # convert the point coordinates to the expected resolution for iterative prompting
             # NOTE:
             # - "only" need to transform the point prompts from the iterative prompting
             # - the `logits` are the low res masks (256, 256), hence do not need the transform
-            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
+            net_coords = self.model.transform.apply_coords_torch(net_coords, _inp["original_size"])
 
             updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                 if "point_coords" in _inp.keys() else net_coords
@@ -285,6 +259,24 @@ def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
     # Training Loop
     #
 
+    def _preprocess_labels(self, y):
+        """Converts the masks to match the shape of "low_res_masks".
+        """
+        # Convert the labels to "low_res_mask" shape
+        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
+        target_length = self.model.transform.target_length
+        target_shape = self.model.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
+        y = F.interpolate(input=y, size=target_shape)
+        # Next, we pad the remaining region to (1024, 1024)
+        h, w = y.shape[-2:]
+        padh = self.model.sam.image_encoder.img_size - h
+        padw = self.model.sam.image_encoder.img_size - w
+        y = F.pad(input=y, pad=(0, padw, 0, padh))
+        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
+        y = F.interpolate(input=y, size=(256, 256))
+
+        return y
+
     def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         """Compute one hot target (one mask per channel) for the sampled ids
         and restrict the number of sampled objects to the minimal number in the batch.
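A quick shape walk-through of the new `_preprocess_labels`, with the model-specific helpers inlined and made-up input sizes so that it runs standalone:

```python
import torch
import torch.nn.functional as F

y = torch.randn(1, 4, 512, 768)  # made-up batch of one-hot label masks

# Step 1: resize so the longest side matches the target length (1024).
scale = 1024 / max(y.shape[2], y.shape[3])
target_shape = (int(y.shape[2] * scale + 0.5), int(y.shape[3] * scale + 0.5))
y = F.interpolate(input=y, size=target_shape)        # -> (1, 4, 683, 1024)

# Step 2: pad bottom and right up to the encoder's square input size.
h, w = y.shape[-2:]
y = F.pad(input=y, pad=(0, 1024 - w, 0, 1024 - h))   # -> (1, 4, 1024, 1024)

# Step 3: downsample to the low-res mask resolution.
y = F.interpolate(input=y, size=(256, 256))          # -> (1, 4, 256, 256)
```

Doing this once per batch in `_preprocess_batch` (below) replaces the per-object `preprocess_one_hot_masks` call that was removed from `_compute_loss`.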
@@ -297,6 +289,7 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         # number of objects across the batch.
         n_objects = min(len(ids) for ids in sampled_ids)
 
+        y = self._preprocess_labels(y)
         y = y.to(self.device)
         # Compute the one hot targets for the seg-id.
         y_one_hot = torch.stack([
2 changes: 1 addition & 1 deletion micro_sam/training/trainable_sam.py
@@ -21,7 +21,7 @@ def __init__(
         self,
         sam: Sam,
         device: Union[str, torch.device],
-        upsampled_masks: bool = True,
+        upsampled_masks: bool = False,
     ) -> None:
         super().__init__()
         self.sam = sam
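Since the default flips to `False`, existing callers silently stop receiving upsampled masks. A hypothetical call site (`sam` is a placeholder; only the changed default is taken from this diff):

```python
from micro_sam.training.trainable_sam import TrainableSAM

# `sam` is a segment_anything Sam instance obtained elsewhere (placeholder).
trainable_sam = TrainableSAM(sam=sam, device="cuda")  # upsampled_masks now defaults to False
# Callers that still need the full-resolution masks must opt in explicitly:
trainable_sam_full = TrainableSAM(sam=sam, device="cuda", upsampled_masks=True)
```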