 
 SAVE_AND_SAMPLE_EVERY = 1000
 UPDATE_EMA_EVERY = 10
-EXTS = ['jpg', 'png']
+EXTS = ['jpg', 'jpeg', 'png']
 
 # helpers functions
 
@@ -263,24 +263,36 @@ def noise_like(shape, device, repeat=False): |
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
 
+def cosine_beta_schedule(timesteps, s = 0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, steps, steps)
+    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, a_min = 0, a_max = 0.999)
+
 class GaussianDiffusion(nn.Module):
-    def __init__(self, denoise_fn, beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=1000, loss_type='l1', betas = None):
+    def __init__(self, denoise_fn, timesteps=1000, loss_type='l1', betas = None):
         super().__init__()
         self.denoise_fn = denoise_fn
 
         if exists(betas):
-            self.np_betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
         else:
-            self.np_betas = betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps).astype(np.float64)
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.loss_type = loss_type
+            betas = cosine_beta_schedule(timesteps)
 
         alphas = 1. - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
         alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
 
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
         to_torch = partial(torch.tensor, dtype=torch.float32)
 
         self.register_buffer('betas', to_torch(betas))
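For context on the second hunk: the new cosine_beta_schedule is the cosine noise schedule from "Improved Denoising Diffusion Probabilistic Models" (the OpenReview link in its docstring). It sets the cumulative alpha product to alpha_bar(t) = f(t) / f(0) with f(t) = cos(((t/T + s) / (1 + s)) * pi / 2)^2, then recovers per-step betas via beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped at 0.999 so the last steps stay finite. A minimal standalone sketch of the same computation (numpy only; the printed checks are illustrative and not part of the diff):

import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar follows a squared cosine in t/T, offset by a small s so beta_0 is not exactly 0
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]      # normalize so alpha_bar(0) = 1
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])   # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    return np.clip(betas, a_min=0, a_max=0.999)

betas = cosine_beta_schedule(1000)
alpha_bar = np.cumprod(1. - betas)
print(betas[0], betas[-1])          # first beta is tiny, last beta hits the 0.999 clip
print(alpha_bar[0], alpha_bar[-1])  # alpha_bar decays smoothly from ~1 toward 0

With the schedule factored out this way, the constructor drops beta_start / beta_end / num_diffusion_timesteps: callers pass timesteps (or an explicit betas array or tensor), and self.num_timesteps is read off the length of whichever beta array is actually used.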