@@ -1,3 +1,4 @@
+import random
 from typing import Any, Dict, List
 
 import numpy as np
@@ -20,7 +21,7 @@ def pad(t: torch.Tensor, padding_length: int) -> torch.Tensor:
         return t[:padding_length], torch.ones(padding_length)
 
 
-def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:
+def get_torch_tensors_from_row_dict(row_dict, keys, cfg_rate) -> Dict[str, Any]:
     """
     Get the latents and prompts from a row dictionary.
     """
@@ -42,7 +43,10 @@ def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:
         bytes = row_dict[f"{key}_bytes"]
 
         # TODO (peiyuan): read precision
-        data = np.frombuffer(bytes, dtype=np.float32).reshape(shape).copy()
+        if key == 'text_embedding' and random.random() < cfg_rate:
+            data = np.zeros((512, 4096), dtype=np.float32)
+        else:
+            data = np.frombuffer(bytes, dtype=np.float32).reshape(shape).copy()
         data = torch.from_numpy(data)
         if len(data.shape) == 3:
             B, L, D = data.shape
@@ -53,8 +57,11 @@ def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:
 
 
 def collate_latents_embs_masks(
-        batch_to_process, text_padding_length,
-        keys) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]]:
+        batch_to_process,
+        text_padding_length,
+        keys,
+        cfg_rate=0.0
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]]:
     # Initialize tensors to hold padded embeddings and masks
     all_latents = []
     all_embs = []
@@ -63,7 +70,7 @@ def collate_latents_embs_masks(
     # Process each row individually
     for i, row in enumerate(batch_to_process):
         # Get tensors from row
-        data = get_torch_tensors_from_row_dict(row, keys, cfg_rate)
+        data = get_torch_tensors_from_row_dict(row, keys, cfg_rate)
         latents, emb = data["vae_latent"], data["text_embedding"]
 
         padded_emb, mask = pad(emb, text_padding_length)
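
For context: this change adds conditioning dropout for classifier-free guidance. With probability cfg_rate, get_torch_tensors_from_row_dict swaps a sample's stored text embedding for an all-zero (512, 4096) array, which serves as the null condition during training. A minimal sketch of a call site follows; the variable names, key list, fourth return value, and 10% rate are illustrative assumptions, not part of this diff.

# Hypothetical usage (assumed names): collate one batch of preprocessed rows
# with conditioning dropout enabled for classifier-free guidance training.
latents, embs, masks, captions = collate_latents_embs_masks(
    batch_to_process=rows,                  # list of row dicts from the dataset (assumed)
    text_padding_length=512,                # matches the (512, 4096) null embedding shape
    keys=["vae_latent", "text_embedding"],  # keys the collate path reads (per the diff)
    cfg_rate=0.1,                           # ~10% of samples train on the zeroed embedding
)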