
Commit 2c0b2a7

Use xb.call_jax for HLO caching

1 parent: a9513c1
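For context on the title: the helper it names lives in torch_xla.core.xla_builder. xb.call_jax traces a JAX callable to HLO and splices the result into the lazy XLA graph, so repeated calls with matching input shapes can reuse the cached computation. The file shown below does not itself call it (the change is in one of the other files in this commit); the following is only a minimal sketch of the general pattern, and the function, shapes, and dtypes are illustrative:

# Illustrative sketch, not code from this commit.
import jax.numpy as jnp
import torch
import torch_xla
import torch_xla.core.xla_builder as xb

def scaled_residual(a, b):
    # A pure JAX function. xb.call_jax traces it to HLO once per input
    # shape/dtype and embeds the computation in the lazy XLA graph.
    return a + 0.5 * jnp.tanh(b)

device = torch_xla.device()
x = torch.randn(4, 4, device=device)
y = torch.randn(4, 4, device=device)
out = xb.call_jax(scaled_residual, (x, y))  # returns an XLA torch.Tensor
torch_xla.sync()  # materialize the graph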

File tree

3 files changed: +243 -84 lines changed

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 107 additions & 73 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
 import torch_xla.distributed.parallel_loader as pl
@@ -20,6 +21,7 @@
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
+from viztracer import VizTracer

 from diffusers import (
     AutoencoderKL,
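For context, the new viztracer import backs the per-phase tracing added to step_fn below. The start/stop/save calls used there are VizTracer's standard API; a minimal standalone sketch of the same pattern (do_work is an illustrative placeholder):

from viztracer import VizTracer

def do_work():
    # Stand-in for the region being profiled (data loading, forward, backward, ...).
    return sum(i * i for i in range(100_000))

tracer = VizTracer()
tracer.start()
do_work()
tracer.stop()
tracer.save(output_file="forward.json")  # open with `vizviewer forward.json`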
@@ -35,6 +37,8 @@
 if is_wandb_available():
     pass

+print(f"torch_xla version {torch_xla.__version__}")
+
 PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
 CACHE_DIR = os.environ.get("CACHE_DIR", None)
 if CACHE_DIR:
@@ -145,14 +149,22 @@ def start_training(self):
         print("max_train_steps: ", self.args.max_train_steps)
         assert measure_start_step < self.args.max_train_steps
         total_time = 0
+        last_time = time.time()
+        tracer = None
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
+            start_time = time.time()
             batch = next(self.dataloader)
+            print(f"dataloading time {time.time()-start_time}")
             if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
                 xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
-                last_time = time.time()
+            if step == 15:
+                tracer = VizTracer()
+            else:
+                tracer = None
             loss = self.step_fn(
+                tracer,
                 batch["model_input"],
                 batch["prompt_embeds"],
                 batch["pooled_prompt_embeds"],
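One caveat about the timing prints added in this loop (an observation, not something the diff addresses): torch_xla executes lazily, so time.time() deltas around host-side calls mostly measure Python tracing rather than device execution. A sketch of a measurement that accounts for this, assuming a recent torch_xla:

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
x = torch.randn(1024, 1024, device=device)

xm.wait_device_ops()              # drain any pending device work first
start_time = time.time()
y = (x @ x).sum()                 # traced lazily; nothing runs yet
torch_xla.sync()                  # cut the graph and launch it
xm.wait_device_ops()              # block until the device finishes
print(f"device step time {time.time() - start_time:.4f}s")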
@@ -182,84 +194,106 @@ def print_loss_closure(step, loss):
 
     def step_fn(
         self,
+        tracer,
         model_input,
         prompt_embeds,
         pooled_prompt_embeds,
         original_sizes,
         crop_top_lefts
     ):
-        with xp.Trace("model.forward"):
-            self.optimizer.zero_grad()
-
-
-            noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
-            bsz = model_input.shape[0]
-            timesteps = torch.randint(
-                0,
-                self.noise_scheduler.config.num_train_timesteps,
-                (bsz,),
-                device=model_input.device,
-            )
-            timesteps = timesteps.long()
-            noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
-            noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
-            # time ids
-            def compute_time_ids(original_size, crops_coords_top_left):
-                # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-                target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
-                add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
-                return add_time_ids
-
-            add_time_ids = torch.cat(
-                [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
-            )
-            # Predict the noise residual
-            unet_added_conditions = {"time_ids": add_time_ids}
-            unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
-            # breakpoint()
-            model_pred = self.unet(
-                noisy_latents,
-                timesteps,
-                prompt_embeds,
-                added_cond_kwargs=unet_added_conditions,
-                return_dict=False,
+        # with VizTracer(output_file="forward.json") as tracer:
+        start_time = time.time()
+        if tracer is not None:
+            tracer.start()
+        self.optimizer.zero_grad()
+        noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
+        bsz = model_input.shape[0]
+        timesteps = torch.randint(
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=model_input.device,
+        )
+        timesteps = timesteps.long()
+        noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
+        noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
+        # time ids
+        def compute_time_ids(original_size, crops_coords_top_left):
+            # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+            target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
+            add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
+            return add_time_ids
+
+        add_time_ids = torch.cat(
+            [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
+        )
+        # Predict the noise residual
+        unet_added_conditions = {"time_ids": add_time_ids}
+        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+        # breakpoint()
+        model_pred = self.unet(
+            noisy_latents,
+            timesteps,
+            prompt_embeds,
+            added_cond_kwargs=unet_added_conditions,
+            return_dict=False,
+        )[0]
+        if self.args.prediction_type is not None:
+            # set prediction_type of scheduler if defined
+            self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
+        if self.noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif self.noise_scheduler.config.prediction_type == "v_prediction":
+            target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
+        elif self.noise_scheduler.config.prediction_type == "sample":
+            # We set the target to latents here, but the model_pred will return the noise sample prediction.
+            target = model_input
+            # We will have to subtract the noise residual from the prediction to get the target sample.
+            model_pred = model_pred - noise
+        else:
+            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="forward.json")
+        print(f"forward_time = {time.time()-start_time}")
+        start_time = time.time()
+        # with VizTracer(output_file="backward.json") as tracer:
+
+        if tracer:
+            tracer.start()
+        if self.args.snr_gamma is None:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        else:
+            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+            # This is discussed in Section 4.2 of the same paper.
+            snr = compute_snr(self.noise_scheduler, timesteps)
+            mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
             )[0]
-            if self.args.prediction_type is not None:
-                # set prediction_type of scheduler if defined
-                self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
             if self.noise_scheduler.config.prediction_type == "epsilon":
-                target = noise
+                mse_loss_weights = mse_loss_weights / snr
             elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
-            elif self.noise_scheduler.config.prediction_type == "sample":
-                # We set the target to latents here, but the model_pred will return the noise sample prediction.
-                target = model_input
-                # We will have to subtract the noise residual from the prediction to get the target sample.
-                model_pred = model_pred - noise
-            else:
-                raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-        with xp.Trace("model.backward"):
-            if self.args.snr_gamma is None:
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-            else:
-                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
-                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
-                # This is discussed in Section 4.2 of the same paper.
-                snr = compute_snr(self.noise_scheduler, timesteps)
-                mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
-                    dim=1
-                )[0]
-                if self.noise_scheduler.config.prediction_type == "epsilon":
-                    mse_loss_weights = mse_loss_weights / snr
-                elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                    mse_loss_weights = mse_loss_weights / (snr + 1)
-
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
-                loss = loss.mean()
-            loss.backward()
-        with xp.Trace("optimizer_step"):
-            self.run_optimizer()
+                mse_loss_weights = mse_loss_weights / (snr + 1)
+
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+            loss = loss.mean()
+        loss.backward()
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="backward.json")
+        print(f"backward time = {time.time()-start_time}")
+        start_time = time.time()
+        # with xp.Trace("optimizer_step"):
+        if tracer:
+            tracer.start()
+        self.run_optimizer()
+        if tracer:
+            tracer.stop()
+            tracer.save(output_file="optimizer.json")
+        print(f"optimizer step = {time.time()-start_time}")
         return loss
 
 
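The weighting logic in step_fn follows min-SNR loss rescaling (Section 3.4 of https://arxiv.org/abs/2303.09556): per-timestep SNR is clamped at gamma, then normalized according to the prediction type. A standalone sketch of the same computation, with illustrative SNR values and gamma:

import torch

def min_snr_weights(snr: torch.Tensor, gamma: float, prediction_type: str) -> torch.Tensor:
    # Clamp the SNR at gamma, then normalize per prediction type.
    weights = torch.minimum(snr, torch.full_like(snr, gamma))
    if prediction_type == "epsilon":         # predicting noise: divide by SNR
        weights = weights / snr
    elif prediction_type == "v_prediction":  # predicting v: divide by SNR + 1
        weights = weights / (snr + 1)
    return weights

snr = torch.tensor([0.5, 4.0, 25.0])
print(min_snr_weights(snr, gamma=5.0, prediction_type="epsilon"))
# tensor([1.0000, 1.0000, 0.2000]) -> high-SNR (low-noise) timesteps are downweighted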
@@ -559,11 +593,11 @@ def get_column_names(dataset, args):
 
 def main(args):
     args = parse_args()
-    cache_path = Path("/tmp/data/compiler_cache")
+    cache_path = Path(os.environ.get('CACHE_DIR', "/mnt/bbahl/xla_cache"))
     cache_path.mkdir(parents=True, exist_ok=True)
     xr.initialize_cache(str(cache_path), readonly=False)
 
-    _ = xp.start_server(PORT)
+    server = xp.start_server(PORT)
 
     num_devices = xr.global_runtime_device_count()
     mesh = xs.get_1d_mesh("data")
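For context on xr.initialize_cache above: it enables torch_xla's persistent compilation cache on disk, so a rerun that produces identical compiled graphs skips XLA recompilation. It should be called once, before the first device computation in the process. A standalone sketch with an illustrative cache path:

import os
from pathlib import Path

import torch
import torch_xla
import torch_xla.runtime as xr

# Illustrative default; the commit reads CACHE_DIR from the environment.
cache_path = Path(os.environ.get("CACHE_DIR", "/tmp/xla_cache"))
cache_path.mkdir(parents=True, exist_ok=True)
xr.initialize_cache(str(cache_path), readonly=False)

device = torch_xla.device()
x = torch.randn(8, 8, device=device)
y = (x @ x).sum()
torch_xla.sync()  # compile on first run, load from the cache afterwards
print(y.item())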
@@ -631,7 +665,7 @@ def main(args):
     if args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    device = xm.xla_device()
+    device = torch_xla.device()
 
     # Move text_encode and vae to device and cast to weight_dtype
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
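The last change swaps xm.xla_device() for torch_xla.device(), the newer spelling recommended in recent torch_xla releases; both return the XLA device for the current process. A quick check, assuming a recent release:

import torch_xla
import torch_xla.core.xla_model as xm

print(torch_xla.device())  # e.g. xla:0
print(xm.xla_device())     # same device via the older spelling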
