-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathedm_pipeline.py
More file actions
426 lines (362 loc) · 16.8 KB
/
edm_pipeline.py
File metadata and controls
426 lines (362 loc) · 16.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
import torch
from megatron.core import parallel_state
from torch import Tensor
from dfm.src.common.utils.batch_ops import batch_mul
from dfm.src.common.utils.torch_split_tensor_for_cp import cat_outputs_cp
from dfm.src.megatron.model.dit.edm.edm_utils import EDMSDE, EDMSampler, EDMScaling
class EDMPipeline:
"""
EDMPipeline is a class that implements a diffusion model pipeline for video generation. It includes methods for
initializing the pipeline, encoding and decoding video data, performing training steps, denoising, and generating
samples.
Attributes:
p_mean: Mean for SDE process.
p_std: Standard deviation for SDE process.
sigma_max: Maximum noise level.
sigma_min: Minimum noise level.
_noise_generator: Generator for noise.
_noise_level_generator: Generator for noise levels.
sde: SDE process.
sampler: Sampler for the diffusion model.
scaling: Scaling for EDM.
input_data_key: Key for input video data.
input_image_key: Key for input image data.
tensor_kwargs: Tensor keyword arguments.
loss_reduce: Method for reducing loss.
loss_scale: Scale factor for loss.
aesthetic_finetuning: Aesthetic finetuning parameter.
camera_sample_weight: Camera sample weight parameter.
loss_mask_enabled: Flag for enabling loss mask.
Methods:
noise_level_generator: Returns the noise level generator.
_initialize_generators: Initializes noise and noise-level generators.
encode: Encodes input tensor using the video tokenizer.
decode: Decodes latent tensor using video tokenizer.
training_step: Performs a single training step for the diffusion model.
denoise: Performs denoising on the input noise data, noise level, and condition.
compute_loss_with_epsilon_and_sigma: Computes the loss for training.
get_per_sigma_loss_weights: Returns loss weights per sigma noise level.
get_condition_uncondition: Returns conditioning and unconditioning for classifier-free guidance.
get_x0_fn_from_batch: Creates a function to generate denoised predictions with the sampler.
generate_samples_from_batch: Generates samples based on input data batch.
_normalize_video_databatch_inplace: Normalizes video data in-place on a CUDA device to [-1, 1].
draw_training_sigma_and_epsilon: Draws training noise (epsilon) and noise levels (sigma).
random_dropout_input: Applies random dropout to the input tensor.
get_data_and_condition: Retrieves data and conditioning for model input.
"""
def __init__(
self,
vae=None,
p_mean=0.0,
p_std=1.0,
sigma_max=80,
sigma_min=0.0002,
sigma_data=0.5,
seed=1234,
):
"""
Initializes the EDM pipeline with the given parameters.
Args:
net: The DiT model.
vae: The Video Tokenizer (optional).
p_mean (float): Mean for the SDE.
p_std (float): Standard deviation for the SDE.
sigma_max (float): Maximum sigma value for the SDE.
sigma_min (float): Minimum sigma value for the SDE.
sigma_data (float): Sigma value for EDM scaling.
seed (int): Random seed for reproducibility.
Attributes:
vae: The Video Tokenizer.
net: The DiT model.
p_mean (float): Mean for the SDE.
p_std (float): Standard deviation for the SDE.
sigma_max (float): Maximum sigma value for the SDE.
sigma_min (float): Minimum sigma value for the SDE.
sigma_data (float): Sigma value for EDM scaling.
seed (int): Random seed for reproducibility.
_noise_generator: Placeholder for noise generator.
_noise_level_generator: Placeholder for noise level generator.
sde: Instance of EDMSDE initialized with p_mean, p_std, sigma_max, and sigma_min.
sampler: Instance of EDMSampler.
scaling: Instance of EDMScaling initialized with sigma_data.
input_data_key (str): Key for input data.
input_image_key (str): Key for input images.
tensor_kwargs (dict): Tensor keyword arguments for device and dtype.
loss_reduce (str): Method to reduce loss ('mean' or other).
loss_scale (float): Scale factor for loss.
"""
self.vae = vae
self.p_mean = p_mean
self.p_std = p_std
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.sigma_data = sigma_data
self.seed = seed
self._noise_generator = None
self._noise_level_generator = None
self.sde = EDMSDE(p_mean, p_std, sigma_max, sigma_min)
self.sampler = EDMSampler()
self.scaling = EDMScaling(sigma_data)
self.input_data_key = "video"
self.input_image_key = "images_1024"
self.tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16}
self.loss_reduce = "mean"
self.loss_scale = 1.0
@property
def noise_level_generator(self):
"""
Generates noise levels for the EDM pipeline.
Returns:
Callable: A function or generator that produces noise levels.
"""
return self._noise_level_generator
def _initialize_generators(self):
"""
Initializes the random number generators for noise and noise level.
This method sets up two generators:
1. A PyTorch generator for noise, seeded with a combination of the base seed and the data parallel rank.
2. A NumPy generator for noise levels, seeded similarly but without considering context parallel rank.
Returns:
None
"""
noise_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=True)
noise_level_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=False)
self._noise_generator = torch.Generator(device="cuda")
self._noise_generator.manual_seed(noise_seed)
self._noise_level_generator = np.random.default_rng(noise_level_seed)
self.sde._generator = self._noise_level_generator
def training_step(
self, model, data_batch: dict[str, torch.Tensor], iteration: int
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
"""
Performs a single training step for the diffusion model.
This method is responsible for executing one iteration of the model's training. It involves:
1. Adding noise to the input data using the SDE process.
2. Passing the noisy data through the network to generate predictions.
3. Computing the loss based on the difference between the predictions and the original data.
Args:
data_batch (dict): raw data batch draw from the training data loader.
iteration (int): Current iteration number.
Returns:
A tuple with the output batch and the computed loss.
"""
# import pdb; pdb.set_trace()
# Get the input data to noise and denoise~(image, video) and the corresponding conditioner.
self.net = model
x0, condition = self.get_data_and_condition(data_batch)
# Sample pertubation noise levels and N(0, 1) noises
sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition)
if parallel_state.is_pipeline_last_stage():
output_batch, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma(x0, condition, epsilon, sigma)
return output_batch, edm_loss
else:
net_output = self.compute_loss_with_epsilon_and_sigma(x0, condition, epsilon, sigma)
return net_output
def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: dict[str, torch.Tensor]):
"""
Performs denoising on the input noise data, noise level, and condition
Args:
xt (torch.Tensor): The input noise data.
sigma (torch.Tensor): The noise level.
condition (dict[str, torch.Tensor]): conditional information
Returns:
Predicted clean data (x0) and noise (eps_pred).
"""
xt = xt.to(**self.tensor_kwargs)
sigma = sigma.to(**self.tensor_kwargs)
# get precondition for the network
c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma)
net_output = self.net(
x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
**condition,
)
if not parallel_state.is_pipeline_last_stage():
return net_output
x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output)
# get noise prediction based on sde
eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma)
return x0_pred, eps_pred
def compute_loss_with_epsilon_and_sigma(
self,
x0: torch.Tensor,
condition: dict[str, torch.Tensor],
epsilon: torch.Tensor,
sigma: torch.Tensor,
):
"""
Computes the loss for training.
Args:
data_batch: Batch of input data.
x0_from_data_batch: Raw input tensor.
x0: Latent tensor.
condition: Conditional input data.
epsilon: Noise tensor.
sigma: Noise level tensor.
Returns:
The computed loss.
"""
# Get the mean and stand deviation of the marginal probability distribution.
mean, std = self.sde.marginal_prob(x0, sigma)
# Generate noisy observations
xt = mean + batch_mul(std, epsilon) # corrupted data
if parallel_state.is_pipeline_last_stage():
# make prediction
x0_pred, eps_pred = self.denoise(xt, sigma, condition)
# loss weights for different noise levels
weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma)
pred_mse = (x0 - x0_pred) ** 2
edm_loss = batch_mul(pred_mse, weights_per_sigma)
output_batch = {
"x0": x0,
"xt": xt,
"sigma": sigma,
"weights_per_sigma": weights_per_sigma,
"condition": condition,
"model_pred": {"x0_pred": x0_pred, "eps_pred": eps_pred},
"mse_loss": pred_mse.mean(),
"edm_loss": edm_loss.mean(),
}
return output_batch, pred_mse, edm_loss
else:
# make prediction
x0_pred = self.denoise(xt, sigma, condition)
return x0_pred.contiguous()
def get_per_sigma_loss_weights(self, sigma: torch.Tensor):
"""
Args:
sigma (tensor): noise level
Returns:
loss weights per sigma noise level
"""
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
def get_condition_uncondition(self, data_batch: Dict):
"""Returns conditioning and unconditioning for classifier-free guidance."""
_, condition = self.get_data_and_condition(data_batch, dropout_rate=0.0)
if "neg_context_embeddings" in data_batch:
data_batch["context_embeddings"] = data_batch["neg_context_embeddings"]
data_batch["context_mask"] = data_batch["context_mask"]
_, uncondition = self.get_data_and_condition(data_batch, dropout_rate=1.0)
else:
_, uncondition = self.get_data_and_condition(data_batch, dropout_rate=1.0)
return condition, uncondition
def get_x0_fn_from_batch(
self,
data_batch: Dict,
guidance: float = 1.5,
is_negative_prompt: bool = False,
) -> Callable:
"""
Creates a function to generate denoised predictions with the sampler.
Args:
data_batch: Batch of input data.
guidance: Guidance scale factor.
is_negative_prompt: Whether to use negative prompts.
Returns:
A callable to predict clean data (x0).
"""
condition, uncondition = self.get_condition_uncondition(data_batch)
def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
cond_x0, _ = self.denoise(noise_x, sigma, condition)
uncond_x0, _ = self.denoise(noise_x, sigma, uncondition)
return cond_x0 + guidance * (cond_x0 - uncond_x0)
return x0_fn
def generate_samples_from_batch(
self,
model,
data_batch: Dict,
guidance: float = 1.5,
state_shape: Tuple | None = None,
is_negative_prompt: bool = False,
num_steps: int = 35,
) -> Tensor:
"""
Generates samples based on input data batch.
Args:
data_batch: Batch of input data.
guidance: Guidance scale factor.
state_shape: Shape of the state.
is_negative_prompt: Whether to use negative prompts.
num_steps: Number of steps for sampling.
solver_option: SDE Solver option.
Returns:
Generated samples from diffusion model.
"""
self.net = model
cp_enabled = parallel_state.get_context_parallel_world_size() > 1
if self._noise_generator is None:
self._initialize_generators()
x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt)
x_sigma_max = (
torch.randn(state_shape, **self.tensor_kwargs, generator=self._noise_generator) * self.sde.sigma_max
)
samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max)
if cp_enabled:
cp_group = parallel_state.get_context_parallel_group()
thd_cu_seqlen_q_padded = data_batch["packed_seq_params"]["self_attention"].cu_seqlens_q_padded
samples = cat_outputs_cp(samples, seq_dim=1, cp_group=cp_group, thd_cu_seqlens=thd_cu_seqlen_q_padded)
return samples
def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor:
"""
Draws training noise (epsilon) and noise levels (sigma).
Args:
x0_size: Shape of the input tensor.
condition: Conditional input (unused).
Returns:
Noise level (sigma) and noise (epsilon).
"""
del condition
batch_size = x0_size[0]
if self._noise_generator is None:
self._initialize_generators()
epsilon = torch.randn(x0_size, **self.tensor_kwargs, generator=self._noise_generator)
return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon
def random_dropout_input(self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None) -> torch.Tensor:
"""
Applies random dropout to the input tensor.
Args:
in_tensor: Input tensor.
dropout_rate: Dropout probability (optional).
Returns:
Conditioning with random dropout applied.
"""
dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
return batch_mul(
torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor),
in_tensor,
)
def get_data_and_condition(self, data_batch: dict[str, Tensor], dropout_rate=0.2) -> Tuple[Tensor]:
"""
Retrieves data and conditioning for model input.
Args:
data_batch: Batch of input data.
dropout_rate: Dropout probability for conditioning.
Returns:
Raw data, latent data, and conditioning information.
"""
# Latent state
latent_state = data_batch["video"] * self.sigma_data
condition = {} # Create a new dictionary for condition
# Copy all keys from data_batch except 'video'
for key, value in data_batch.items():
if key not in ["video", "context_embeddings"]:
condition[key] = value
condition["crossattn_emb"] = self.random_dropout_input(
data_batch["context_embeddings"], dropout_rate=dropout_rate
)
return latent_state, condition