 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
 import torch_xla.distributed.parallel_loader as pl

 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
+from viztracer import VizTracer
 

 from diffusers import (
     AutoencoderKL,

 if is_wandb_available():
     pass
 
+print(f"torch_xla version {torch_xla.__version__}")
+
 PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
 CACHE_DIR = os.environ.get("CACHE_DIR", None)
 if CACHE_DIR:
@@ -145,14 +149,22 @@ def start_training(self): |
         print("max_train_steps: ", self.args.max_train_steps)
         assert measure_start_step < self.args.max_train_steps
         total_time = 0
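+        # Per-step wall-clock bookkeeping; the VizTracer handle is created
+        # lazily inside the loop below.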
+        last_time = time.time()
+        tracer = None
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
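+            # Time host-side dataloading for this step.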
+            start_time = time.time()
             batch = next(self.dataloader)
+            print(f"dataloading time {time.time()-start_time}")
             if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
                 xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
-                last_time = time.time()
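+            # Create a VizTracer for exactly one step (step 15, presumably past
+            # XLA compilation/warm-up) so the host-side Python profile covers a
+            # representative steady-state step.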
+            if step == 15:
+                tracer = VizTracer()
+            else:
+                tracer = None
             loss = self.step_fn(
+                tracer,
                 batch["model_input"],
                 batch["prompt_embeds"],
                 batch["pooled_prompt_embeds"],
@@ -182,84 +194,106 @@ def print_loss_closure(step, loss): |
 
     def step_fn(
         self,
+        tracer,
         model_input,
         prompt_embeds,
         pooled_prompt_embeds,
         original_sizes,
         crop_top_lefts
     ):
-        with xp.Trace("model.forward"):
-            self.optimizer.zero_grad()
-
-
-            noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
-            bsz = model_input.shape[0]
-            timesteps = torch.randint(
-                0,
-                self.noise_scheduler.config.num_train_timesteps,
-                (bsz,),
-                device=model_input.device,
-            )
-            timesteps = timesteps.long()
-            noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
-            noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
-            # time ids
-            def compute_time_ids(original_size, crops_coords_top_left):
-                # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-                target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
-                add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
-                return add_time_ids
-
-            add_time_ids = torch.cat(
-                [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
-            )
-            # Predict the noise residual
-            unet_added_conditions = {"time_ids": add_time_ids}
-            unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
-            # breakpoint()
-            model_pred = self.unet(
-                noisy_latents,
-                timesteps,
-                prompt_embeds,
-                added_cond_kwargs=unet_added_conditions,
-                return_dict=False,
| 204 | + # with VizTracer(output_file="forward.json") as tracer: |
| 205 | + start_time = time.time() |
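+        # Host-side timing of the forward section. Note: with torch_xla's lazy
+        # tensors this measures Python tracing time; device execution is deferred.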
+        if tracer is not None:
+            tracer.start()
+        self.optimizer.zero_grad()
+        noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
+        bsz = model_input.shape[0]
+        timesteps = torch.randint(
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=model_input.device,
+        )
+        timesteps = timesteps.long()
+        noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
+        noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
+        # time ids
+        def compute_time_ids(original_size, crops_coords_top_left):
+            # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+            target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
+            add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
+            return add_time_ids
+
+        add_time_ids = torch.cat(
+            [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
+        )
+        # Predict the noise residual
+        unet_added_conditions = {"time_ids": add_time_ids}
+        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+        # breakpoint()
+        model_pred = self.unet(
+            noisy_latents,
+            timesteps,
+            prompt_embeds,
+            added_cond_kwargs=unet_added_conditions,
+            return_dict=False,
+        )[0]
+        if self.args.prediction_type is not None:
+            # set prediction_type of scheduler if defined
+            self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
+        if self.noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif self.noise_scheduler.config.prediction_type == "v_prediction":
+            target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
+        elif self.noise_scheduler.config.prediction_type == "sample":
+            # We set the target to latents here, but the model_pred will return the noise sample prediction.
+            target = model_input
+            # We will have to subtract the noise residual from the prediction to get the target sample.
+            model_pred = model_pred - noise
+        else:
+            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="forward.json")
+        print(f"forward_time = {time.time()-start_time}")
+        start_time = time.time()
+        # with VizTracer(output_file="backward.json") as tracer:
+
+        if tracer:
+            tracer.start()
+        if self.args.snr_gamma is None:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        else:
+            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+            # This is discussed in Section 4.2 of the same paper.
+            snr = compute_snr(self.noise_scheduler, timesteps)
+            mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
             )[0]
-            if self.args.prediction_type is not None:
-                # set prediction_type of scheduler if defined
-                self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
             if self.noise_scheduler.config.prediction_type == "epsilon":
-                target = noise
+                mse_loss_weights = mse_loss_weights / snr
             elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
-            elif self.noise_scheduler.config.prediction_type == "sample":
-                # We set the target to latents here, but the model_pred will return the noise sample prediction.
-                target = model_input
-                # We will have to subtract the noise residual from the prediction to get the target sample.
-                model_pred = model_pred - noise
-            else:
-                raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-        with xp.Trace("model.backward"):
-            if self.args.snr_gamma is None:
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-            else:
-                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
-                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
-                # This is discussed in Section 4.2 of the same paper.
-                snr = compute_snr(self.noise_scheduler, timesteps)
-                mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
-                    dim=1
-                )[0]
-                if self.noise_scheduler.config.prediction_type == "epsilon":
-                    mse_loss_weights = mse_loss_weights / snr
-                elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                    mse_loss_weights = mse_loss_weights / (snr + 1)
-
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
-                loss = loss.mean()
-            loss.backward()
-        with xp.Trace("optimizer_step"):
-            self.run_optimizer()
+                mse_loss_weights = mse_loss_weights / (snr + 1)
+
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+            loss = loss.mean()
+        loss.backward()
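+        # loss.backward() likewise only records the XLA graph here; actual device
+        # execution is triggered later (e.g. by the optimizer step / mark_step).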
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="backward.json")
+        print(f"backward time = {time.time()-start_time}")
+        start_time = time.time()
+        # with xp.Trace("optimizer_step"):
+        if tracer:
+            tracer.start()
+        self.run_optimizer()
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="optimizer.json")
+        print(f"optimizer step = {time.time()-start_time}")
         return loss
 
 
@@ -559,11 +593,11 @@ def get_column_names(dataset, args): |
 
 def main(args):
     args = parse_args()
-    cache_path = Path("/tmp/data/compiler_cache")
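+    # Persistent XLA compilation cache; the location can be overridden via the
+    # CACHE_DIR environment variable.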
+    cache_path = Path(os.environ.get('CACHE_DIR', "/mnt/bbahl/xla_cache"))
     cache_path.mkdir(parents=True, exist_ok=True)
     xr.initialize_cache(str(cache_path), readonly=False)
 
-    _ = xp.start_server(PORT)
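+    # Keep a reference to the profiler server: if the returned handle is
+    # garbage-collected, the profiling endpoint is torn down.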
+    server = xp.start_server(PORT)
 
     num_devices = xr.global_runtime_device_count()
     mesh = xs.get_1d_mesh("data")
@@ -631,7 +665,7 @@ def main(args): |
     if args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    device = xm.xla_device()
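+    # torch_xla.device() is the newer replacement for the xm.xla_device() helper.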
+    device = torch_xla.device()
 
     # Move text_encoder and vae to device and cast to weight_dtype
     text_encoder = text_encoder.to(device, dtype=weight_dtype)