
Commit 298709e

remove custom scheduler
1 parent 3be6706 commit 298709e

File tree

1 file changed (+3, -104 lines)

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 104 deletions
@@ -1265,109 +1265,6 @@ def encode_prompt(
 
     return prompt_embeds, pooled_prompt_embeds, text_ids
 
-
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps
-
-
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1499,7 +1396,7 @@ def main(args):
     )
 
     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -2337,6 +2234,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                                         removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                         shutil.rmtree(removing_checkpoint)
 
+                            # save embeddings
+
                             save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                             accelerator.save_state(save_path)
                             logger.info(f"Saved state to {save_path}")
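
The removed subclass mostly duplicated what the stock FlowMatchEulerDiscreteScheduler and diffusers.training_utils already provide, which is presumably why it could be dropped with a one-line swap in main(). Below is a minimal sketch of that correspondence; the model id and latent shape are placeholder assumptions for illustration, not code from this commit.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

# Placeholder model id; any flow-matching checkpoint with a scheduler subfolder works.
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)

batch_size = 4
model_input = torch.randn(batch_size, 16, 64, 64)  # stand-in latents
noise = torch.randn_like(model_input)

# Density-based timestep sampling stands in for the removed set_train_timesteps(linear=False).
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=batch_size, logit_mean=0.0, logit_std=1.0
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices]

# Flow-matching forward process, as in the removed add_noise():
# z_t = (1 - sigma) * x + sigma * noise
sigmas = noise_scheduler.sigmas[indices].view(-1, 1, 1, 1)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

# The sigma**-2 weighting computed in the removed __init__ is the stock
# "sigma_sqrt" scheme of compute_loss_weighting_for_sd3.
loss_weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)

One behavioral difference worth flagging: the custom class clamped its weights at 1e4 and normalized them to a mean of 1, while compute_loss_weighting_for_sd3 returns the raw sigmas**-2.0 values.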

0 commit comments
