143 changes: 57 additions & 86 deletions blob_loss.py
@@ -26,17 +26,17 @@ def compute_compound_loss(
weight = entry["weight"]

sigmoid = entry["sigmoid"]
if blob_loss_mode == False:
if not blob_loss_mode:
vprint("computing main loss!")
if sigmoid == True:
if sigmoid:
sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
individual_loss = criterion(sigmoid_network_outputs, label)
else:
individual_loss = criterion(raw_network_outputs, label)
elif blob_loss_mode == True:
elif blob_loss_mode:
vprint("computing blob loss!")
if masked == True: # this is the default blob loss
if sigmoid == True:
if masked: # this is the default blob loss
if sigmoid:
sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
individual_loss = compute_blob_loss_multi(
criterion=criterion,
@@ -49,8 +49,8 @@ def compute_compound_loss(
network_outputs=raw_network_outputs,
multi_label=label,
)
elif masked == False: # without masking for ablation study
if sigmoid == True:
elif not masked: # without masking for ablation study
if sigmoid:
sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
individual_loss = compute_no_masking_multi(
criterion=criterion,
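
Note on the interface: `compute_compound_loss` reads `entry["weight"]` and `entry["sigmoid"]` from each element of its loss list, so the expected configuration is presumably a list of dicts. A minimal sketch of such a configuration; the `criterion` key and the concrete loss choices are assumptions for illustration, not confirmed by this diff:

```python
import torch.nn as nn

# Hypothetical compound-loss configuration: only the "weight" and "sigmoid"
# keys are visible in the code above; the "criterion" key is assumed.
losses = [
    # BCEWithLogitsLoss applies the sigmoid internally, so sigmoid=False here
    {"criterion": nn.BCEWithLogitsLoss(), "weight": 1.0, "sigmoid": False},
    # a criterion that expects probabilities gets sigmoid=True
    {"criterion": nn.MSELoss(), "weight": 0.5, "sigmoid": True},
]
```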
@@ -78,97 +78,68 @@ def compute_blob_loss_multi(
multi_label: torch.Tensor,
):
"""
1. loop through elements in our batch
2. loop through blobs per element compute loss and divide by blobs to have element loss
2.1 we need to account for sigmoid and non/sigmoid in conjunction with BCE
3. divide by batch length to have a correct batch loss for back prop
Compute blob loss for a batch of multi-label instance masks: for each blob label,
mask out all other blobs in the network output, evaluate the criterion against
that blob, then average over blobs and over batch elements.
"""
batch_length = multi_label.shape[0]
vprint("batch_length:", batch_length)

element_blob_loss = []
# loop over elements
for element in range(batch_length):
if element < batch_length:
end_index = element + 1
elif element == batch_length:
end_index = None

element_label = multi_label[element:end_index, ...]
vprint("element label shape:", element_label.shape)

vprint("element_label:", element_label.shape)

element_output = network_outputs[element:end_index, ...]

# loop through labels
unique_labels = torch.unique(element_label)
blob_count = len(unique_labels) - 1
vprint("found this amount of blobs in batch element:", blob_count)

label_loss = []
for ula in unique_labels:
if ula == 0:
vprint("ula is 0 we do nothing")
else:
# first we need one hot labels
vprint("ula greater than 0:", ula)
label_mask = element_label > 0
# we flip labels
label_mask = ~label_mask

# we set the mask to true where our label of interest is located
# vprint(torch.count_nonzero(label_mask))
label_mask[element_label == ula] = 1
# vprint(torch.count_nonzero(label_mask))
vprint("label_mask", label_mask)
# vprint("torch.unique(label_mask):", torch.unique(label_mask))

the_label = element_label == ula
the_label_int = the_label.int()
vprint("the_label:", torch.count_nonzero(the_label))


# debugging
# masked_label = the_label * label_mask
# vprint("masked_label:", torch.count_nonzero(masked_label))

masked_output = element_output * label_mask

try:
# we try with int labels first, but some losses require floats
blob_loss = criterion(masked_output, the_label_int)
except:
# if int does not work we try float
blob_loss = criterion(masked_output, the_label.float())
vprint("blob_loss:", blob_loss)

label_loss.append(blob_loss)

# compute mean
vprint("label_loss:", label_loss)
# mean_label_loss = 0
vprint("blobs in crop:", len(label_loss))
if not len(label_loss) == 0:
mean_label_loss = sum(label_loss) / len(label_loss)
# mean_label_loss = sum(label_loss) / \
# torch.count_nonzero(label_loss)
vprint("mean_label_loss", mean_label_loss)
element_blob_loss.append(mean_label_loss)

# compute mean
vprint("element_blob_loss:", element_blob_loss)
mean_element_blob_loss = 0
vprint("elements in batch:", len(element_blob_loss))
if not len(element_blob_loss) == 0:
mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)
# element_blob_loss) / torch.count_nonzero(element_blob_loss)
# Process the entire batch at once
unique_labels = torch.unique(multi_label)
# number of distinct blobs per batch element (unique labels minus background;
# assumes the background label 0 is present in each element)
blob_counts = [int(torch.unique(multi_label[i]).numel()) - 1 for i in range(batch_length)]
vprint("found this amount of blobs in batch elements:", blob_counts)

for ula in unique_labels:
if ula == 0:
vprint("ula is 0 we do nothing")
continue

vprint("ula greater than 0:", ula)

# Build masks for the whole batch: True on background and on the current blob,
# False on all other blobs (which are thereby masked out of the output)
label_mask = (multi_label > 0)
label_mask = ~label_mask
label_mask[multi_label == ula] = 1
vprint("label_mask", label_mask)

the_label = (multi_label == ula)
the_label_int = the_label.int()
vprint("the_label:", torch.count_nonzero(the_label))

masked_output = network_outputs * label_mask

try:
# try integer labels first; some criteria require int targets
blob_loss = criterion(masked_output, the_label_int)
except Exception:
# fall back to float targets (e.g. BCE-style losses)
blob_loss = criterion(masked_output, the_label.float())
vprint("blob_loss:", blob_loss)

# reduce to one loss per batch element; this requires the criterion to
# return per-voxel losses (i.e. it must be constructed with reduction="none")
element_losses = blob_loss.view(batch_length, -1).mean(dim=1)
element_blob_loss.append(element_losses)

# Stack losses for all labels
if element_blob_loss:
all_losses = torch.stack(element_blob_loss, dim=1)

# Compute mean loss for each element, ignoring zeros
mask = all_losses != 0
element_mean_losses = torch.sum(all_losses * mask, dim=1) / torch.sum(mask, dim=1).clamp(min=1)

# Compute overall mean loss
mean_element_blob_loss = element_mean_losses.mean()
else:
mean_element_blob_loss = torch.tensor(0.0, device=network_outputs.device)

vprint("mean_element_blob_loss", mean_element_blob_loss)

return mean_element_blob_loss
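
Usage sketch for the vectorized path above: since the per-element reduction calls `blob_loss.view(batch_length, -1)`, the criterion must return an unreduced per-voxel loss tensor. A minimal example, with shapes and the BCE choice as illustrative assumptions:

```python
import torch
import torch.nn as nn

from blob_loss import compute_blob_loss_multi

# reduction="none" is required so losses can be re-reduced per batch element
criterion = nn.BCEWithLogitsLoss(reduction="none")

network_outputs = torch.randn(2, 1, 64, 64, requires_grad=True)  # logits (N, C, H, W)
multi_label = torch.zeros(2, 1, 64, 64)  # instance-labeled mask, 0 = background
multi_label[0, 0, 10:20, 10:20] = 1      # blob 1 in batch element 0
multi_label[1, 0, 30:40, 30:40] = 2      # blob 2 in batch element 1

loss = compute_blob_loss_multi(
    criterion=criterion,
    network_outputs=network_outputs,
    multi_label=multi_label,
)
loss.backward()
```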



def compute_no_masking_multi(
criterion,
network_outputs: torch.Tensor,
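
General usage note: blob loss operates on instance-labeled masks in which each connected component ("blob") carries a distinct positive integer. If the ground truth is binary, it can be converted with connected-component labeling; a sketch assuming `scipy` is available (the helper name is hypothetical):

```python
import numpy as np
import torch
from scipy import ndimage

def binary_to_blob_labels(binary_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: label each connected component with a unique integer."""
    labeled, num_blobs = ndimage.label(binary_mask.cpu().numpy())
    return torch.from_numpy(labeled.astype(np.int64))

binary = torch.zeros(1, 64, 64)
binary[0, 5:10, 5:10] = 1
binary[0, 40:50, 40:50] = 1
multi_label = binary_to_blob_labels(binary)  # components become labels 1 and 2
```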