@@ -203,9 +203,9 @@ def __init__(self, config: VQBeTConfig):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )
 
         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,10 +273,10 @@ def configure_parameters(self):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
         )
 
         decay = [param_dict[pn] for pn in sorted(decay)]
@@ -419,9 +419,9 @@ def get_codebook_vector_from_indices(self, indices):
         # and the network should be able to reconstruct
 
         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
 
         # get ready for gathering
@@ -472,9 +472,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=
         all_indices = []
 
         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []
 
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ def calculate_ce_loss(codes):
             # only calculate orthogonal loss for the activated codes for this batch
 
             if self.orthogonal_reg_active_codes_only:
-                assert not (
-                    is_multiheaded and self.separate_codebook_per_head
-                ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                assert not (is_multiheaded and self.separate_codebook_per_head), (
+                    "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                )
                 unique_code_ids = torch.unique(embed_ind)
                 codebook = codebook[:, unique_code_ids]
@@ -999,9 +999,9 @@ def gumbel_sample(
     ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)
 
-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )
 
     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot
@@ -1209,9 +1209,9 @@ def __init__(
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp
 
-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
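
Every hunk above makes the same mechanical change: instead of parenthesizing the assert condition so it can span multiple lines, the condition stays on the assert line and the long message is what gets wrapped in parentheses. A minimal before/after sketch of the pattern, using placeholder names rather than code from this file:

# Old style: the condition is wrapped so the statement fits the line limit.
assert (
    value <= limit
), f"Cannot use value {value}, limit is only {limit}"

# New style (what this patch converges on): the condition stays inline
# and only the long message is parenthesized and wrapped.
assert value <= limit, (
    f"Cannot use value {value}, limit is only {limit}"
)

Both forms assert the same condition and raise with the same message; only the formatting differs.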