diff --git a/blob_loss.py b/blob_loss.py
index 7e1b52a..ff97183 100644
--- a/blob_loss.py
+++ b/blob_loss.py
@@ -26,17 +26,17 @@ def compute_compound_loss(
         weight = entry["weight"]
         sigmoid = entry["sigmoid"]
-        if blob_loss_mode == False:
+        if not blob_loss_mode:
             vprint("computing main loss!")
-            if sigmoid == True:
+            if sigmoid:
                 sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
                 individual_loss = criterion(sigmoid_network_outputs, label)
             else:
                 individual_loss = criterion(raw_network_outputs, label)
-        elif blob_loss_mode == True:
+        elif blob_loss_mode:
             vprint("computing blob loss!")
-            if masked == True:  # this is the default blob loss
-                if sigmoid == True:
+            if masked:  # this is the default blob loss
+                if sigmoid:
                     sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
                     individual_loss = compute_blob_loss_multi(
                         criterion=criterion,
                         network_outputs=sigmoid_network_outputs,
@@ -49,8 +49,8 @@ def compute_compound_loss(
                         network_outputs=raw_network_outputs,
                         multi_label=label,
                     )
-            elif masked == False:  # without masking for ablation study
-                if sigmoid == True:
+            elif not masked:  # without masking for ablation study
+                if sigmoid:
                     sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
                     individual_loss = compute_no_masking_multi(
                         criterion=criterion,
@@ -78,97 +78,71 @@ def compute_blob_loss_multi(
     multi_label: torch.Tensor,
 ):
     """
-    1. loop through elements in our batch
-    2. loop through blobs per element compute loss and divide by blobs to have element loss
-    2.1 we need to account for sigmoid and non/sigmoid in conjunction with BCE
-    3. divide by batch length to have a correct batch loss for back prop
+    Compute blob loss for a batch of multi-label volumes, vectorized over the batch.
     """
     batch_length = multi_label.shape[0]
     vprint("batch_length:", batch_length)
 
     element_blob_loss = []
-    # loop over elements
-    for element in range(batch_length):
-        if element < batch_length:
-            end_index = element + 1
-        elif element == batch_length:
-            end_index = None
-
-        element_label = multi_label[element:end_index, ...]
-        vprint("element label shape:", element_label.shape)
-
-        vprint("element_label:", element_label.shape)
-
-        element_output = network_outputs[element:end_index, ...]
-
-        # loop through labels
-        unique_labels = torch.unique(element_label)
-        blob_count = len(unique_labels) - 1
-        vprint("found this amount of blobs in batch element:", blob_count)
-
-        label_loss = []
-        for ula in unique_labels:
-            if ula == 0:
-                vprint("ula is 0 we do nothing")
-            else:
-                # first we need one hot labels
-                vprint("ula greater than 0:", ula)
-                label_mask = element_label > 0
-                # we flip labels
-                label_mask = ~label_mask
-
-                # we set the mask to true where our label of interest is located
-                # vprint(torch.count_nonzero(label_mask))
-                label_mask[element_label == ula] = 1
-                # vprint(torch.count_nonzero(label_mask))
-                vprint("label_mask", label_mask)
-                # vprint("torch.unique(label_mask):", torch.unique(label_mask))
-
-                the_label = element_label == ula
-                the_label_int = the_label.int()
-                vprint("the_label:", torch.count_nonzero(the_label))
-
-                # debugging
-                # masked_label = the_label * label_mask
-                # vprint("masked_label:", torch.count_nonzero(masked_label))
-
-                masked_output = element_output * label_mask
-
-                try:
-                    # we try with int labels first, but some losses require floats
-                    blob_loss = criterion(masked_output, the_label_int)
-                except:
-                    # if int does not work we try float
-                    blob_loss = criterion(masked_output, the_label.float())
-                vprint("blob_loss:", blob_loss)
-
-                label_loss.append(blob_loss)
-        # compute mean
-        vprint("label_loss:", label_loss)
-        # mean_label_loss = 0
-        vprint("blobs in crop:", len(label_loss))
-        if not len(label_loss) == 0:
-            mean_label_loss = sum(label_loss) / len(label_loss)
-            # mean_label_loss = sum(label_loss) / \
-            #     torch.count_nonzero(label_loss)
-            vprint("mean_label_loss", mean_label_loss)
-            element_blob_loss.append(mean_label_loss)
-
-    # compute mean
-    vprint("element_blob_loss:", element_blob_loss)
-    mean_element_blob_loss = 0
-    vprint("elements in batch:", len(element_blob_loss))
-    if not len(element_blob_loss) == 0:
-        mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)
-        # element_blob_loss) / torch.count_nonzero(element_blob_loss)
+    # Process the entire batch at once
+    unique_labels = torch.unique(multi_label)
+    foreground_voxel_counts = torch.sum(multi_label > 0, dim=tuple(range(1, multi_label.ndim)))
+    vprint("foreground voxels per batch element:", foreground_voxel_counts)
+
+    for ula in unique_labels:
+        if ula == 0:
+            vprint("ula is 0 we do nothing")
+            continue
+
+        vprint("ula greater than 0:", ula)
+
+        # Create masks for the entire batch
+        label_mask = multi_label > 0
+        label_mask = ~label_mask
+        label_mask[multi_label == ula] = 1
+        vprint("label_mask", label_mask)
+
+        the_label = multi_label == ula
+        the_label_int = the_label.int()
+        vprint("the_label:", torch.count_nonzero(the_label))
+
+        masked_output = network_outputs * label_mask
+
+        try:
+            # we try with int labels first, but some losses require floats
+            blob_loss = criterion(masked_output, the_label_int)
+        except (TypeError, RuntimeError):
+            # if int does not work we try float
+            blob_loss = criterion(masked_output, the_label.float())
+        vprint("blob_loss:", blob_loss)
+
+        # Average the per-voxel losses per batch element; requires a criterion with reduction="none"
+        element_losses = blob_loss.reshape(batch_length, -1).mean(dim=1)
+        # Zero out batch elements that do not contain this label so the zero-filter below ignores them
+        has_label = the_label.reshape(batch_length, -1).any(dim=1)
+        element_losses = element_losses * has_label
+        element_blob_loss.append(element_losses)
+
+    # Stack losses for all labels
+    if element_blob_loss:
+        all_losses = torch.stack(element_blob_loss, dim=1)
+
+        # Compute mean loss for each element, ignoring zeros
+        mask = all_losses != 0
+        element_mean_losses = torch.sum(all_losses * mask, dim=1) / torch.sum(mask, dim=1).clamp(min=1)
+
+        # Compute overall mean loss
+        mean_element_blob_loss = element_mean_losses.mean()
+    else:
+        mean_element_blob_loss = torch.tensor(0.0, device=network_outputs.device)
 
     vprint("mean_element_blob_loss", mean_element_blob_loss)
 
     return mean_element_blob_loss
 
+
 def compute_no_masking_multi(
     criterion,
     network_outputs: torch.Tensor,
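
Review note on the `compute_compound_loss` hunks: the branch that prints "computing main loss!" is taken when blob loss is off, so the modernized condition has to be `if not blob_loss_mode:`; plain `if blob_loss_mode:` would silently swap the two branches. A minimal sketch of a call that exercises this branch follows. Only the `"weight"` and `"sigmoid"` keys are visible in this diff, so the `criterion_dict` parameter name, the dict-of-dicts layout, and the `"loss"` key are assumptions, not the repository's confirmed API:

```python
import torch

from blob_loss import compute_compound_loss

# Hypothetical loss configuration; only the "weight" and "sigmoid" keys
# appear in this diff, the rest of the layout is assumed.
criterion_dict = {
    "bce": {
        "loss": torch.nn.BCEWithLogitsLoss(),  # reduces to a scalar for the main loss
        "weight": 1.0,
        "sigmoid": False,  # BCEWithLogitsLoss applies the sigmoid itself
    },
}

raw_network_outputs = torch.randn(2, 1, 16, 16, requires_grad=True)
label = (torch.rand(2, 1, 16, 16) > 0.9).float()

# blob_loss_mode=False should reach the "computing main loss!" branch,
# which is why the condition must be `if not blob_loss_mode:`.
loss = compute_compound_loss(
    criterion_dict=criterion_dict,
    raw_network_outputs=raw_network_outputs,
    label=label,
    blob_loss_mode=False,
)
```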
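
Review note on the vectorized `compute_blob_loss_multi`: it reshapes per-voxel losses to `(batch, -1)`, so it only works with a criterion built with `reduction="none"`; a criterion that reduces to a scalar would make `blob_loss.reshape(batch_length, -1)` fail. A smoke-test sketch under that assumption (shapes and label values are made up for illustration):

```python
import torch

from blob_loss import compute_blob_loss_multi

# Assumed shapes: batch of 2, one channel, 16x16 crops; labels are
# connected-component ids with 0 as background, as in the old loop.
multi_label = torch.zeros(2, 1, 16, 16)
multi_label[0, 0, 2:5, 2:5] = 1    # first blob in element 0
multi_label[0, 0, 9:12, 9:12] = 2  # second blob in element 0
multi_label[1, 0, 4:8, 4:8] = 1    # single blob in element 1

network_outputs = torch.randn(2, 1, 16, 16, requires_grad=True)

# reduction="none" keeps per-voxel losses so the vectorized path can
# average them per batch element.
criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

loss = compute_blob_loss_multi(
    criterion=criterion,
    network_outputs=network_outputs,
    multi_label=multi_label,
)
loss.backward()  # gradients flow through the masked outputs
print(loss.item())
```

The `has_label` filter added above makes a batch element that lacks a given label (element 1 has no label 2 here) contribute an exact zero for it, which the `all_losses != 0` mask then drops; without that filter, the element's background-only loss for the absent label would be averaged in, which the original per-element loop never did.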