Skip to content

Commit aeeb9ee

Browse files
committed
add the general logic for the zero shot mask weighting of loss
1 parent 7527f48 commit aeeb9ee

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,13 @@ values = wrapped_to_values(
8080

8181
## Todo
8282

83-
- [ ] add the zero-shot masking of concept proposed in paper
8483
- [ ] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
8584
- [ ] handle rank-1 update for multiple concepts
8685
- [x] handle training with multiple concepts
8786
- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
8887
- [ ] offer a magic function that automatically tries to wire up the cross attention by looking for appropriately named `nn.Linear` and auto-inferring which ones are keys or values
88+
89+
- [x] add the zero-shot masking of concept proposed in paper
8990
- [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
9091
- [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
9192

perfusion_pytorch/perfusion.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,36 @@ def return_text_enc_with_concept_and_superclass(
8181

8282
return concept_text_enc, concept_indices, superclass_text_enc
8383

84+
# loss weighted by the mask
85+
86+
@beartype
87+
def loss_fn_weighted_by_mask(
88+
pred: FloatTensor,
89+
target: FloatTensor,
90+
mask: FloatTensor,
91+
normalized_mask_min_value = 0.
92+
):
93+
assert mask.shape[-2:] == pred.shape[-2:] == target.shape[-2:]
94+
assert mask.shape[0] == pred.shape[0] == target.shape[0]
95+
96+
assert (mask.amin() >= 0.).all(), 'mask should not have values below 0'
97+
98+
if mask.ndim == 4:
99+
assert mask.shape[1] == 1
100+
mask = rearrange(mask, 'b 1 h w -> b h w')
101+
102+
loss = F.mse_loss(pred, target, reduction = 'none')
103+
loss = reduce(loss, 'b c h w -> b h w')
104+
105+
# normalize mask by max
106+
107+
normalized_mask = mask / mask.amax(dim = -1, keepdim = True).clamp(min = 1e-5)
108+
normalized_mask = normalized_mask.clamp(min = normalized_mask_min_value)
109+
110+
loss = loss * normalized_mask
111+
112+
return loss.mean()
113+
84114
# a module that wraps the keys and values projection of the cross attentions to text encodings
85115

86116
class Rank1EditModule(Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.21',
6+
version = '0.0.22',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)