
Commit a88ab93

Add a function that accepts an open clip model and a list of prompts (List[str]) and returns the input covariance matrix C needed for Rank-1 editing
1 parent f095f57

File tree

5 files changed: +128, −4 lines


README.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -10,6 +10,8 @@ It seems they successfully applied the Rank-1 editing technique from a <a href="
 
 - <a href="https://stability.ai/">StabilityAI</a> for the generous sponsorship, as well as my other sponsors out there
 
+- All the maintainers at <a href="https://github.com/mlfoundations/open_clip">OpenClip</a>, for their SOTA open sourced contrastive learning text-image models
+
 ## Install
 
 ```bash
````

perfusion_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,3 +1,4 @@
 from perfusion_pytorch.perfusion import (
-    Rank1EditModule
+    Rank1EditModule,
+    calculate_input_covariance
 )
```
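
Since `__init__.py` re-exports the new function, it should be importable from the package root alongside `Rank1EditModule`; a one-line sanity check, assuming the package is installed:

```python
# both names are now exposed at the top level of the package
from perfusion_pytorch import Rank1EditModule, calculate_input_covariance
```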

perfusion_pytorch/open_clip.py (new file)

Lines changed: 90 additions & 0 deletions

```python
from beartype import beartype
from beartype.typing import List, Optional

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

import open_clip

def exists(val):
    return val is not None

def l2norm(t):
    return F.normalize(t, dim = -1)

class OpenClipAdapter(nn.Module):
    @beartype
    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32',
        tokenizer_name = 'ViT-B-32-quickgelu',
        eos_id = 49407
    ):
        super().__init__()

        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
        tokenizer = open_clip.get_tokenizer(tokenizer_name)

        self.clip = clip
        self.tokenizer = tokenizer
        self.eos_id = eos_id

        # hook for getting final text representation

        text_attention_final = self.find_layer('ln_final')
        self._dim_latent = text_attention_final.weight.shape[0]
        self.text_handle = text_attention_final.register_forward_hook(self._text_hook)

        # normalize fn

        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    @property
    def device(self):
        return next(self.parameters()).device

    def find_layer(self, layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    def clear(self):
        if self.cleared:
            return

        # forward hooks are detached with .remove() on the handle
        self.text_handle.remove()
        self.cleared = True

    def _text_hook(self, _, inputs, outputs):
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return self._dim_latent

    @property
    def max_text_len(self):
        return self.clip.positional_embedding.shape[0]

    @beartype
    def embed_texts(
        self,
        texts: List[str]
    ):
        ids = self.tokenizer(texts)
        ids = ids.to(self.device)
        ids = ids[..., :self.max_text_len]

        # mask is True for all tokens up to and including the first EOS, excluding padding

        is_eos_id = (ids == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (ids != 0)

        assert not self.cleared

        # the encode_text forward pass populates self.text_encodings through the hook

        self.clip.encode_text(ids)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return text_encodings.float(), text_mask
```
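
The subtle part of `embed_texts` is the mask construction: the `cumsum` over EOS matches is zero strictly before the first EOS, and the `F.pad(..., (1, -1), value = True)` shift re-includes the EOS position itself. A standalone sketch of just that logic, on made-up token ids (only the default EOS id of 49407 is taken from the adapter above):

```python
import torch
import torch.nn.functional as F

eos_id = 49407

# made-up token ids: two prompts, zero-padded after the EOS token
ids = torch.tensor([
    [320, 1125, eos_id, 0, 0],
    [320, eos_id, 0, 0, 0]
])

is_eos_id = (ids == eos_id)

# True strictly before the first EOS (cumsum is still 0 there)
mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0

# shift right by one so the EOS position itself is kept as well
mask = F.pad(mask_excluding_eos, (1, -1), value = True)

# padding ids are excluded regardless
mask = mask & (ids != 0)

print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True, False, False, False]])
```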

perfusion_pytorch/perfusion.py

Lines changed: 32 additions & 2 deletions

```diff
@@ -1,22 +1,52 @@
+from math import ceil
 from beartype import beartype
-from beartype.typing import Union
+from beartype.typing import Union, List, Optional
 
 import torch
-from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor, Optional
+from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
 from torch.nn import Module
 import torch.nn.functional as F
 
 from einops import rearrange
 
 from opt_einsum import contract as opt_einsum
 
+from perfusion_pytorch.open_clip import OpenClipAdapter
+
 # helpers
 
 def exists(val):
     return val is not None
 
 IndicesTensor = Union[LongTensor, IntTensor]
 
+# function for calculating C - the input covariance
+
+@beartype
+@torch.no_grad()
+def calculate_input_covariance(
+    open_clip: OpenClipAdapter,
+    texts: List[str],
+    batch_size = 32,
+    **cov_kwargs
+):
+    num_batches = ceil(len(texts) / batch_size)
+
+    all_embeds = []
+
+    for batch_ind in range(num_batches):
+        start_index = batch_ind * batch_size
+        batch_texts = texts[start_index:(start_index + batch_size)]
+
+        # only the masked (non-padding) token encodings contribute to the covariance
+        embeds, mask = open_clip.embed_texts(batch_texts)
+        all_embeds.append(embeds[mask])
+
+    all_embeds = torch.cat(all_embeds, dim = 0)
+    all_embeds = rearrange(all_embeds, 'n d -> d n')
+    return torch.cov(all_embeds, **cov_kwargs)
+
 # a module that wraps the keys and values projection of the cross attentions to text encodings
 
 class Rank1EditModule(Module):
```
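
The function gathers the masked per-token encodings across all prompts into an (n, d) matrix, transposes it because `torch.cov` expects variables as rows and observations as columns, and returns a (d, d) covariance estimate. A minimal usage sketch, with placeholder prompts and assuming network access to download the default ViT-B/32 weights (a real run would use a much larger, representative caption set):

```python
from perfusion_pytorch.open_clip import OpenClipAdapter
from perfusion_pytorch.perfusion import calculate_input_covariance

clip = OpenClipAdapter()  # defaults to ViT-B/32 with laion400m weights

# placeholder prompts, for illustration only
texts = [
    'a photo of a dog',
    'a watercolor painting of a castle',
    'a photo of a teapot on a table',
    'an illustration of a rocket'
]

C = calculate_input_covariance(clip, texts, batch_size = 2)

# covariance over the text encoding dimension (512 for ViT-B/32)
print(C.shape)  # torch.Size([512, 512])
```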

setup.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
@@ -19,6 +19,7 @@
   install_requires=[
     'beartype',
     'einops>=0.6.1',
+    'open-clip-torch>=2.0.0,<3.0.0',
     'opt-einsum',
     'torch>=2.0'
   ],
```
