@@ -12,39 +12,43 @@ class VQ(nn.Module):
     Quantization layer from *Neural Discrete Representation Learning*
 
     Args:
-        latent_dim (int): number of features along which to quantize
-        num_tokens (int): number of tokens in the codebook
+        embedding_dim (int): number of features along which to quantize
+        num_embeddings (int): number of tokens in the codebook
         dim (int): dimension along which to quantize
         return_indices (bool): whether to return the indices of the quantized
            code points
    """
+
    embedding: nn.Embedding
    dim: int
    commitment: float
    initialized: torch.Tensor
    return_indices: bool
    init_mode: str
 
-    def __init__(self,
-                 latent_dim: int,
-                 num_tokens: int,
-                 dim: int = 1,
-                 commitment: float = 0.25,
-                 init_mode: str = 'normal',
-                 space="l2",
-                 return_indices: bool = True,
-                 max_age: int = 1000):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        *,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        space="l2",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
        super(VQ, self).__init__()
-        self.latent_dim = latent_dim
-        self.embedding = nn.Embedding(num_tokens, latent_dim)
+        self.embedding_dim = embedding_dim
+        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.embedding.weight, 0, 1.1)
        self.dim = dim
        self.commitment = commitment
-        self.register_buffer('initialized', torch.Tensor([0]))
+        self.register_buffer("initialized", torch.Tensor([0]))
        self.return_indices = return_indices
-        assert init_mode in ['normal', 'first']
+        assert init_mode in ["normal", "first"]
        self.init_mode = init_mode
-        self.register_buffer('age', torch.empty(num_tokens).fill_(max_age))
+        self.age = nn.Buffer(torch.empty(num_embeddings).fill_(max_age))
        self.max_age = max_age
        self.space = space
        assert space in ["l2", "angular"]
@@ -66,12 +70,13 @@ def resample_dead(self, x):
        if len(dead) == 0:
            return
 
-        print(f'{len(dead)} dead codes resampled')
+        print(f"{len(dead)} dead codes resampled")
        x_flat = x.view(-1, x.shape[-1])
        emb_weight = self.embedding.weight.data
-        emb_weight[dead[:len(x_flat)]] = x_flat[torch.randperm(
-            len(x_flat))[:len(dead)]].to(emb_weight.dtype)
-        self.age[dead[:len(x_flat)]] = 0
+        emb_weight[dead[: len(x_flat)]] = x_flat[
+            torch.randperm(len(x_flat))[: len(dead)]
+        ].to(emb_weight.dtype)
+        self.age[dead[: len(x_flat)]] = 0
 
        if torch.distributed.is_initialized():
            torch.distributed.broadcast(emb_weight, 0)
@@ -94,11 +99,10 @@ def forward(
        else:
            return self.lookup(x)
 
-    def lookup(
-            self, x: torch.Tensor
-    ) -> torch.Tensor:
+    def lookup(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (..., K)
        dim = self.dim
-        needs_transpose = dim != -1 or dim != x.dim() - 1
+        needs_transpose = dim not in (-1, x.dim() - 1)
 
        x = self.embedding(x)
        if self.space == "angular":
@@ -109,6 +113,7 @@ def lookup(
            dims.insert(dim, dims[-1])
            dims.pop()
            x = x.permute(*dims)
+        # x: (..., D)
        return x
 
    def quantize(
@@ -118,17 +123,16 @@ def quantize(
        nb_codes = self.embedding.weight.shape[0]
 
        codebook = self.embedding.weight
-        if (self.init_mode == 'first' and self.initialized.item() == 0
-                and self.training):
+        if self.init_mode == "first" and self.initialized.item() == 0 and self.training:
            n_proto = self.embedding.weight.shape[0]
 
            ch_first = x.transpose(dim, -1).contiguous().view(-1, x.shape[dim])
            n_samples = ch_first.shape[0]
-            idx = torch.randint(0, n_samples, (n_proto, ))[:nb_codes]
+            idx = torch.randint(0, n_samples, (n_proto,))[:nb_codes]
            self.embedding.weight.data.copy_(ch_first[idx])
            self.initialized[:] = 1
 
-        needs_transpose = dim != -1 or dim != x.dim() - 1
+        needs_transpose = dim not in (-1, x.dim() - 1)
        if needs_transpose:
            x = x.transpose(-1, dim).contiguous()
 
@@ -139,7 +143,8 @@ def quantize(
            codebook = F.normalize(codebook, dim=1)
            x = F.normalize(x, dim=-1)
 
-        codes, indices = quantize(x, codebook, self.commitment, -1)
+        # x: (..., D)
+        codes, indices = quantize(x, codebook, self.commitment)
 
        if self.training:
            self.update_usage(indices)
@@ -160,39 +165,47 @@ class MultiVQ(nn.Module):
    Learning*
 
    Args:
-        latent_dim (int): number of features along which to quantize
-        num_tokens (int): number of tokens in the codebook
+        embedding_dim (int): number of features along which to quantize
+        num_embeddings (int): number of tokens in the codebook
        num_codebooks (int): number of parallel codebooks
        dim (int): dimension along which to quantize
            an angular distance
        return_indices (bool): whether to return the indices of the quantized
            code points
    """
 
-    def __init__(self,
-                 latent_dim: int,
-                 num_tokens: int,
-                 num_codebooks: int,
-                 dim: int = 1,
-                 commitment: float = 0.25,
-                 init_mode: str = 'normal',
-                 return_indices: bool = True,
-                 max_age: int = 1000):
-        assert latent_dim % num_codebooks == 0, (
-            "num_codebooks must divide evenly latent_dim")
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_embeddings: int,
+        num_codebooks: int,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
+        assert (
+            embedding_dim % num_codebooks == 0
+        ), "num_codebooks must divide evenly embedding_dim"
        super(MultiVQ, self).__init__()
        self.dim = dim
        self.num_codebooks = num_codebooks
        self.return_indices = return_indices
-        self.vqs = nn.ModuleList([
-            VQ(latent_dim // num_codebooks,
-               num_tokens,
-               dim=dim,
-               commitment=commitment,
-               init_mode=init_mode,
-               return_indices=return_indices,
-               max_age=max_age) for _ in range(num_codebooks)
-        ])
+        self.vqs = nn.ModuleList(
+            [
+                VQ(
+                    num_embeddings,
+                    embedding_dim // num_codebooks,
+                    dim=dim,
+                    commitment=commitment,
+                    init_mode=init_mode,
+                    return_indices=return_indices,
+                    max_age=max_age,
+                )
+                for _ in range(num_codebooks)
+            ]
+        )
 
    def forward(
        self, x: torch.Tensor
@@ -206,13 +219,79 @@ def forward(
        return torch.cat(quantized, dim=self.dim)
 
 
-class MultiVQ2(VQ):
+class RVQ(nn.Module):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        num_codebooks: int,
+        *,
+        dim: int = 1,
+        commitment: float = 0.25,
+        init_mode: str = "normal",
+        return_indices: bool = True,
+        max_age: int = 1000,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.return_indices = return_indices
+        self.codebooks = nn.ModuleList(
+            [
+                VQ(
+                    num_embeddings,
+                    embedding_dim,
+                    dim=-1,
+                    commitment=commitment,
+                    init_mode=init_mode,
+                    return_indices=True,
+                    max_age=max_age,
+                )
+                for _ in range(num_codebooks)
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        dim = self.dim
+        needs_transpose = dim not in (-1, x.dim() - 1)
+        if needs_transpose:
+            x = x.transpose(-1, dim).contiguous()
+
+        out = torch.zeros_like(x)
+        indices = []
+        for i, cb in enumerate(self.codebooks):
+            this_codes, this_indices = cb(x - out)
+            out += this_codes
+            print("residual", torch.norm(x - out).item())
+            indices.append(this_indices)
+
+        indices = torch.cat(indices, dim=-1)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        d = self.latent_dim
-        dims = x.shape
-        batched_dims = list(dims)
-        batched_dims[self.dim] = d
-        batched_dims[self.dim - 1] = -1
-        out = super(MultiVQ2, self).forward(x.view(*batched_dims))
-        return out.view(*dims).contiguous()
+        if needs_transpose:
+            out = out.transpose(-1, dim).contiguous()
+            indices = indices.transpose(-1, dim).contiguous()
+
+        if self.return_indices:
+            return out, indices
+        else:
+            return out
+
+    def lookup(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (..., K)
+        dim = self.dim
+        needs_transpose = dim not in (-1, x.dim() - 1)
+
+        x = torch.stack(
+            [cb.lookup(xx) for cb, xx in zip(self.codebooks, x.split(1, dim=-1))],
+            dim=-1,
+        )
+        x = x.sum(-1)
+
+        if needs_transpose:
+            dims = list(range(x.ndim))
+            dims.insert(dim, dims[-1])
+            dims.pop()
+            x = x.permute(*dims)
+        # x: (..., D)
+        return x
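
A minimal usage sketch of the renamed constructors and the new RVQ layer (not part of the diff above; the import path is hypothetical and assumes VQ and RVQ are exported by this module):

    import torch
    from vq import VQ, RVQ  # hypothetical import path

    x = torch.randn(8, 64, 32)  # (batch, channels, time); dim=1 picks the channel axis

    # new argument order: num_embeddings first, embedding_dim second,
    # remaining options are keyword-only
    vq = VQ(512, 64, dim=1)
    codes, indices = vq(x)  # return_indices=True by default

    # residual VQ: each codebook quantizes the residual left by the previous ones
    rvq = RVQ(512, 64, num_codebooks=4, dim=1)
    out, idx = rvq(x)

Note that the `self.age = nn.Buffer(...)` form replaces `register_buffer('age', ...)` and therefore requires a PyTorch version that provides `torch.nn.Buffer`.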