@@ -1265,109 +1265,6 @@ def encode_prompt(
 
     return prompt_embeds, pooled_prompt_embeds, text_ids
 
-
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps
-
-
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1499,7 +1396,7 @@ def main(args):
     )
 
     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -2337,6 +2234,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                                     removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                     shutil.rmtree(removing_checkpoint)
 
+                        # save embeddings
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
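For context on what stands in for the removed class: with `CustomFlowMatchEulerDiscreteScheduler` gone, the training loop relies on the stock `FlowMatchEulerDiscreteScheduler` together with the script's `get_sigmas` helper (named in the last hunk header) to perform the flow-matching interpolation that the removed `add_noise` referenced in its comments, `z_t = (1 - sigma) * x + sigma * noise`. The sketch below is illustrative only, not the training script itself: it constructs the scheduler with its default config instead of loading it from the checkpoint's `scheduler` subfolder, uses toy latent shapes, and samples timesteps uniformly rather than with the script's weighted density sampling.

```python
# Illustrative sketch of flow-matching noising with the stock scheduler.
# Assumptions (not taken from the diff): default scheduler config, toy
# 2x16x64x64 latents, uniform timestep sampling.
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()  # the script loads this via from_pretrained(..., subfolder="scheduler")


def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    # Map each sampled training timestep to its sigma and broadcast it to the
    # latent's dimensionality (same role as the helper named in the hunk header).
    sigmas = scheduler.sigmas.to(dtype=dtype)
    schedule_timesteps = scheduler.timesteps
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while sigma.ndim < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


# Toy latents and noise standing in for the VAE-encoded training batch.
model_input = torch.randn(2, 16, 64, 64)
noise = torch.randn_like(model_input)

# Uniformly sample training timesteps from the schedule (the real script
# weights this sampling instead of drawing uniformly).
indices = torch.randint(0, scheduler.config.num_train_timesteps, (model_input.shape[0],))
timesteps = scheduler.timesteps[indices]

# Flow-matching interpolation: z_t = (1 - sigma) * x + sigma * noise.
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
print(noisy_model_input.shape)  # torch.Size([2, 16, 64, 64])
```

The actual training loop typically also derives a loss weighting from these sigmas; that part is omitted here to keep the sketch focused on the noising step.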