Commit 2b0d45f

towards rvq

1 parent 2dd2720 commit 2b0d45f

2 files changed: +140 -61 lines changed

torchelie/nn/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from .conv import *
 from .debug import Debug, Dummy
 from .noise import Noise
-from .vq import VQ, MultiVQ, MultiVQ2
+from .vq import VQ, MultiVQ, RVQ
 from .imagenetinputnorm import ImageNetInputNorm
 from .withsavedactivations import WithSavedActivations
 from .maskedconv import MaskedConv2d, TopLeftConv2d
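For context, the export surface after this change, as a minimal sketch (assuming this revision of torchelie is installed; sizes are illustrative):

import torchelie.nn as tnn

vq = tnn.VQ(512, 64, dim=1)                     # 512 codes of 64 features
mvq = tnn.MultiVQ(64, 512, num_codebooks=4)     # 4 parallel codebooks
rvq = tnn.RVQ(512, 64, num_codebooks=4, dim=1)  # 4 residual codebooks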

torchelie/nn/vq.py

Lines changed: 139 additions & 60 deletions
@@ -12,39 +12,43 @@ class VQ(nn.Module):
     Quantization layer from *Neural Discrete Representation Learning*

     Args:
-        latent_dim (int): number of features along which to quantize
-        num_tokens (int): number of tokens in the codebook
+        embedding_dim (int): number of features along which to quantize
+        num_embeddings (int): number of tokens in the codebook
         dim (int): dimension along which to quantize
         return_indices (bool): whether to return the indices of the quantized
             code points
     """
+
     embedding: nn.Embedding
     dim: int
     commitment: float
     initialized: torch.Tensor
     return_indices: bool
     init_mode: str

-    def __init__(self,
-                 latent_dim: int,
-                 num_tokens: int,
-                 dim: int = 1,
-                 commitment: float = 0.25,
-                 init_mode: str = 'normal',
-                 space="l2",
-                 return_indices: bool = True,
-                 max_age: int = 1000):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        *,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        space="l2",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
         super(VQ, self).__init__()
-        self.latent_dim = latent_dim
-        self.embedding = nn.Embedding(num_tokens, latent_dim)
+        self.embedding_dim = embedding_dim
+        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
         nn.init.normal_(self.embedding.weight, 0, 1.1)
         self.dim = dim
         self.commitment = commitment
-        self.register_buffer('initialized', torch.Tensor([0]))
+        self.register_buffer("initialized", torch.Tensor([0]))
         self.return_indices = return_indices
-        assert init_mode in ['normal', 'first']
+        assert init_mode in ["normal", "first"]
         self.init_mode = init_mode
-        self.register_buffer('age', torch.empty(num_tokens).fill_(max_age))
+        self.age = nn.Buffer(torch.empty(num_embeddings).fill_(max_age))
         self.max_age = max_age
         self.space = space
         assert space in ["l2", "angular"]
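The constructor now takes (num_embeddings, embedding_dim) positionally, matching nn.Embedding, with everything after the bare * keyword-only; note that the age buffer now uses nn.Buffer, which assumes a recent PyTorch (2.4+). A usage sketch under those assumptions, with illustrative shapes:

import torch
from torchelie.nn import VQ

vq = VQ(512, 64, dim=1, commitment=0.25)  # codebook: 512 entries of 64 features
x = torch.randn(8, 64, 32, 32)            # features live along dim=1
codes, indices = vq(x)                    # codes: (8, 64, 32, 32); indices: (8, 32, 32)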
@@ -66,12 +70,13 @@ def resample_dead(self, x):
         if len(dead) == 0:
             return

-        print(f'{len(dead)} dead codes resampled')
+        print(f"{len(dead)} dead codes resampled")
         x_flat = x.view(-1, x.shape[-1])
         emb_weight = self.embedding.weight.data
-        emb_weight[dead[:len(x_flat)]] = x_flat[torch.randperm(
-            len(x_flat))[:len(dead)]].to(emb_weight.dtype)
-        self.age[dead[:len(x_flat)]] = 0
+        emb_weight[dead[: len(x_flat)]] = x_flat[
+            torch.randperm(len(x_flat))[: len(dead)]
+        ].to(emb_weight.dtype)
+        self.age[dead[: len(x_flat)]] = 0

         if torch.distributed.is_initialized():
             torch.distributed.broadcast(emb_weight, 0)
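resample_dead re-seeds codebook entries that have gone unused for more than max_age steps from random vectors of the current batch, then broadcasts from rank 0 so replicas stay in sync under DDP. A standalone sketch of the mechanism (names are illustrative, not the torchelie API):

import torch

def resample_dead_codes(codebook: torch.Tensor, age: torch.Tensor,
                        x_flat: torch.Tensor, max_age: int) -> None:
    # codebook: (K, D), age: (K,), x_flat: (N, D) vectors from the current batch
    dead = torch.nonzero(age > max_age).squeeze(-1)
    if len(dead) == 0:
        return
    # replace at most N dead codes with randomly chosen batch vectors
    perm = torch.randperm(len(x_flat))[: len(dead)]
    codebook[dead[: len(x_flat)]] = x_flat[perm].to(codebook.dtype)
    age[dead[: len(x_flat)]] = 0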
@@ -94,11 +99,10 @@ def forward(
         else:
             return self.lookup(x)

-    def lookup(
-        self, x: torch.Tensor
-    ) -> torch.Tensor:
+    def lookup(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (..., K)
         dim = self.dim
-        needs_transpose = dim != -1 or dim != x.dim() - 1
+        needs_transpose = dim not in (-1, x.dim() - 1)

         x = self.embedding(x)
         if self.space == "angular":
@@ -109,6 +113,7 @@ def lookup(
             dims.insert(dim, dims[-1])
             dims.pop()
             x = x.permute(*dims)
+        # x: (..., D)
         return x

     def quantize(
@@ -118,17 +123,16 @@ def quantize(
         nb_codes = self.embedding.weight.shape[0]

         codebook = self.embedding.weight
-        if (self.init_mode == 'first' and self.initialized.item() == 0
-                and self.training):
+        if self.init_mode == "first" and self.initialized.item() == 0 and self.training:
             n_proto = self.embedding.weight.shape[0]

             ch_first = x.transpose(dim, -1).contiguous().view(-1, x.shape[dim])
             n_samples = ch_first.shape[0]
-            idx = torch.randint(0, n_samples, (n_proto, ))[:nb_codes]
+            idx = torch.randint(0, n_samples, (n_proto,))[:nb_codes]
             self.embedding.weight.data.copy_(ch_first[idx])
             self.initialized[:] = 1

-        needs_transpose = dim != -1 or dim != x.dim() - 1
+        needs_transpose = dim not in (-1, x.dim() - 1)
         if needs_transpose:
             x = x.transpose(-1, dim).contiguous()
@@ -139,7 +143,8 @@ def quantize(
             codebook = F.normalize(codebook, dim=1)
             x = F.normalize(x, dim=-1)

-        codes, indices = quantize(x, codebook, self.commitment, -1)
+        # x: (..., D)
+        codes, indices = quantize(x, codebook, self.commitment)

         if self.training:
             self.update_usage(indices)
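quantize() now always operates on the trailing dim, so the explicit -1 argument is gone. For orientation, a self-contained sketch of the straight-through nearest-neighbour quantization that such a function performs (illustrative; not the torchelie quantize autograd function itself):

import torch
import torch.nn.functional as F

def st_quantize(x: torch.Tensor, codebook: torch.Tensor, commitment: float = 0.25):
    # x: (..., D), codebook: (K, D)
    flat = x.reshape(-1, x.shape[-1])
    indices = torch.cdist(flat, codebook).argmin(dim=-1).view(x.shape[:-1])
    codes = codebook[indices]                        # (..., D)
    # codebook loss plus the commitment term pulling the encoder toward its code
    loss = F.mse_loss(codes, x.detach()) + commitment * F.mse_loss(x, codes.detach())
    codes = x + (codes - x).detach()                 # straight-through gradient
    return codes, indices, loss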
@@ -160,39 +165,47 @@ class MultiVQ(nn.Module):
     Learning*

     Args:
-        latent_dim (int): number of features along which to quantize
-        num_tokens (int): number of tokens in the codebook
+        embedding_dim (int): number of features along which to quantize
+        num_embeddings (int): number of tokens in the codebook
         num_codebooks (int): number of parallel codebooks
         dim (int): dimension along which to quantize
             an angular distance
         return_indices (bool): whether to return the indices of the quantized
             code points
     """

-    def __init__(self,
-                 latent_dim: int,
-                 num_tokens: int,
-                 num_codebooks: int,
-                 dim: int = 1,
-                 commitment: float = 0.25,
-                 init_mode: str = 'normal',
-                 return_indices: bool = True,
-                 max_age: int = 1000):
-        assert latent_dim % num_codebooks == 0, (
-            "num_codebooks must divide evenly latent_dim")
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_embeddings: int,
+        num_codebooks: int,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
+        assert (
+            embedding_dim % num_codebooks == 0
+        ), "num_codebooks must evenly divide embedding_dim"
         super(MultiVQ, self).__init__()
         self.dim = dim
         self.num_codebooks = num_codebooks
         self.return_indices = return_indices
-        self.vqs = nn.ModuleList([
-            VQ(latent_dim // num_codebooks,
-               num_tokens,
-               dim=dim,
-               commitment=commitment,
-               init_mode=init_mode,
-               return_indices=return_indices,
-               max_age=max_age) for _ in range(num_codebooks)
-        ])
+        self.vqs = nn.ModuleList(
+            [
+                VQ(
+                    num_embeddings,
+                    embedding_dim // num_codebooks,
+                    dim=dim,
+                    commitment=commitment,
+                    init_mode=init_mode,
+                    return_indices=return_indices,
+                    max_age=max_age,
+                )
+                for _ in range(num_codebooks)
+            ]
+        )

     def forward(
         self, x: torch.Tensor
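MultiVQ splits the feature dimension into num_codebooks independent groups quantized in parallel, in contrast to the residual scheme introduced below. A usage sketch (illustrative shapes; assumes return_indices=True yields a (codes, indices) pair, per the docstring):

import torch
from torchelie.nn import MultiVQ

mvq = MultiVQ(embedding_dim=64, num_embeddings=512, num_codebooks=4, dim=1)
x = torch.randn(8, 64, 32, 32)
codes, indices = mvq(x)   # each codebook quantizes one (8, 16, 32, 32) slice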
@@ -206,13 +219,79 @@ def forward(
         return torch.cat(quantized, dim=self.dim)


-class MultiVQ2(VQ):
+class RVQ(nn.Module):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        num_codebooks: int,
+        *,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.return_indices = return_indices
+        self.codebooks = nn.ModuleList(
+            [
+                VQ(
+                    num_embeddings,
+                    embedding_dim,
+                    dim=-1,
+                    commitment=commitment,
+                    init_mode=init_mode,
+                    return_indices=True,
+                    max_age=max_age,
+                )
+                for _ in range(num_codebooks)
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        dim = self.dim
+        needs_transpose = dim not in (-1, x.dim() - 1)
+        if needs_transpose:
+            x = x.transpose(-1, dim).contiguous()
+
+        out = torch.zeros_like(x)
+        indices = []
+        for cb in self.codebooks:
+            # each codebook quantizes what the previous ones left over
+            this_codes, this_indices = cb(x - out)
+            out += this_codes
+            print("residual", torch.norm(x - out).item())
+            indices.append(this_indices)
+
+        # one index per codebook for every position: (..., num_codebooks)
+        indices = torch.stack(indices, dim=-1)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        d = self.latent_dim
-        dims = x.shape
-        batched_dims = list(dims)
-        batched_dims[self.dim] = d
-        batched_dims[self.dim - 1] = -1
-        out = super(MultiVQ2, self).forward(x.view(*batched_dims))
-        return out.view(*dims).contiguous()
+        if needs_transpose:
+            out = out.transpose(-1, dim).contiguous()
+            indices = indices.transpose(-1, dim).contiguous()
+
+        if self.return_indices:
+            return out, indices
+        else:
+            return out
+
+    def lookup(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (..., K), one index per codebook in the last dim
+        dim = self.dim
+        needs_transpose = dim not in (-1, x.dim() - 1)
+
+        # squeeze the singleton left by split(1) so each VQ sees plain indices
+        x = torch.stack(
+            [
+                cb.lookup(xx.squeeze(-1))
+                for cb, xx in zip(self.codebooks, x.split(1, dim=-1))
+            ],
+            dim=-1,
+        )
+        x = x.sum(-1)
+
+        if needs_transpose:
+            dims = list(range(x.ndim))
+            dims.insert(dim, dims[-1])
+            dims.pop()
+            x = x.permute(*dims)
+        # x: (..., D)
+        return x
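RVQ is the point of this commit: the codebooks are applied sequentially, each quantizing the residual left by the previous ones, so the sum of the selected codes approximates the input increasingly well (the print above traces that residual shrinking). A round-trip sketch with dim=-1, under the same illustrative-shape assumptions:

import torch
from torchelie.nn import RVQ

rvq = RVQ(512, 64, num_codebooks=4, dim=-1)
x = torch.randn(8, 128, 64)
codes, indices = rvq(x)        # codes: (8, 128, 64); indices: (8, 128, 4)
recon = rvq.lookup(indices)    # sum of per-codebook embeddings: (8, 128, 64)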
