Commit 55f1717

Use call_jax to cache HLO
1 parent 3e0cd93 commit 55f1717
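
Note on the commit title: the file shown below does not reference call_jax itself, so the HLO-caching change presumably lives in one of the other three changed files. For orientation only, here is a minimal, hypothetical sketch of caching a function's HLO via the JAX bridge, assuming a torch_xla build that exposes torch_xla.core.xla_builder.call_jax (the exact module path and signature vary between releases):

# Hypothetical sketch, not part of this commit's diff.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb  # assumed location of call_jax


def scaled_add(a, b):
    # A plain JAX-traceable function; call_jax lowers it to HLO once per
    # input shape/dtype and reuses the cached computation on later calls.
    return a * 2.0 + b


device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = torch.randn(4, 4, device=device)

# The lowered HLO is stitched into the surrounding lazy-tensor graph
# instead of being re-traced on every training step.
out = xb.call_jax(scaled_add, (x, y))
xm.mark_step()
print(out.cpu())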

File tree

4 files changed, +312 -304 lines changed

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 91 additions & 109 deletions
@@ -22,7 +22,7 @@
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
 from transformers.trainer_pt_utils import get_module_class_from_name
-from viztracer import VizTracer
+# from viztracer import VizTracer
 
 from torch._dispatch.python import suspend_functionalization
 from torch._subclasses.functional_tensor import disable_functional_mode
@@ -145,7 +145,7 @@ def wrap_module(
 
 def add_checkpoints(model):
     remat_classes = [get_module_class_from_name(model, "BasicTransformerBlock")]
-    import pdb; pdb.set_trace()
+    # import pdb; pdb.set_trace()
     def maybe_checkpoint(mod):
         if isinstance(mod, tuple(remat_classes)):
             return checkpoint_module(mod)
@@ -172,6 +172,7 @@ def __init__(
         self.mesh = xs.get_global_mesh()
         self.dataloader = iter(dataloader)
         self.global_step = 0
+        # self.step_fn_compiled = torch.compile(self.step_fn, backend="openxla")
 
     def run_optimizer(self):
         self.optimizer.step()
@@ -184,7 +185,6 @@ def start_training(self):
         assert measure_start_step < self.args.max_train_steps
         total_time = 0
         last_time = time.time()
-        tracer = None
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
             start_time = time.time()
@@ -193,13 +193,9 @@ def start_training(self):
             if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
                 xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
-            if step == 15:
-                tracer = VizTracer()
-            else:
-                tracer = None
+
             with suspend_functionalization(), disable_functional_mode():
                 loss = self.step_fn(
-                    tracer,
                     batch["model_input"],
                     batch["prompt_embeds"],
                     batch["pooled_prompt_embeds"],
@@ -229,107 +225,91 @@ def print_loss_closure(step, loss):
 
     def step_fn(
         self,
-        tracer,
         model_input,
         prompt_embeds,
         pooled_prompt_embeds,
         original_sizes,
         crop_top_lefts
     ):
-        # with VizTracer(output_file="forward.json") as tracer:
-        start_time = time.time()
-        if tracer is not None:
-            tracer.start()
-        self.optimizer.zero_grad()
-        noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
-        bsz = model_input.shape[0]
-        timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            device=model_input.device,
-        )
-        timesteps = timesteps.long()
-        noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
-        noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
-        # time ids
-        def compute_time_ids(original_size, crops_coords_top_left):
-            # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-            target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
-            add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
-            return add_time_ids
-
-        add_time_ids = torch.cat(
-            [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
-        )
-        # Predict the noise residual
-        unet_added_conditions = {"time_ids": add_time_ids}
-        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
-        # breakpoint()
-        model_pred = self.unet(
-            noisy_latents,
-            timesteps,
-            prompt_embeds,
-            added_cond_kwargs=unet_added_conditions,
-            return_dict=False,
-        )[0]
-        if self.args.prediction_type is not None:
-            # set prediction_type of scheduler if defined
-            self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
-        if self.noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif self.noise_scheduler.config.prediction_type == "v_prediction":
-            target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
-        elif self.noise_scheduler.config.prediction_type == "sample":
-            # We set the target to latents here, but the model_pred will return the noise sample prediction.
-            target = model_input
-            # We will have to subtract the noise residual from the prediction to get the target sample.
-            model_pred = model_pred - noise
-        else:
-            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="forward.json")
-        print(f"forward_time = {time.time()-start_time}")
         start_time = time.time()
-        # with VizTracer(output_file="backward.json") as tracer:
-
-        if tracer:
-            tracer.start()
-        if self.args.snr_gamma is None:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-        else:
-            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
-            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
-            # This is discussed in Section 4.2 of the same paper.
-            snr = compute_snr(self.noise_scheduler, timesteps)
-            mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
-                dim=1
+        with xp.Trace("optimizer_zero_grad"):
+            self.optimizer.zero_grad(True)
+        with xp.Trace("forward"):
+            noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
+            bsz = model_input.shape[0]
+            timesteps = torch.randint(
+                0,
+                self.noise_scheduler.config.num_train_timesteps,
+                (bsz,),
+                device=model_input.device,
+            )
+            timesteps = timesteps.long()
+            noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
+            noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
+            # time ids
+            def compute_time_ids(original_size, crops_coords_top_left):
+                # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+                target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
+                add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
+                return add_time_ids
+
+            add_time_ids = torch.cat(
+                [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
+            )
+            # Predict the noise residual
+            unet_added_conditions = {"time_ids": add_time_ids}
+            unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+            # breakpoint()
+            model_pred = self.unet(
+                noisy_latents,
+                timesteps,
+                prompt_embeds,
+                added_cond_kwargs=unet_added_conditions,
+                return_dict=False,
             )[0]
+            if self.args.prediction_type is not None:
+                # set prediction_type of scheduler if defined
+                self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
             if self.noise_scheduler.config.prediction_type == "epsilon":
-                mse_loss_weights = mse_loss_weights / snr
+                target = noise
             elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                mse_loss_weights = mse_loss_weights / (snr + 1)
-
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
-            loss = loss.mean()
-        loss.backward()
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="backward.json")
+                target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
+            elif self.noise_scheduler.config.prediction_type == "sample":
+                # We set the target to latents here, but the model_pred will return the noise sample prediction.
+                target = model_input
+                # We will have to subtract the noise residual from the prediction to get the target sample.
+                model_pred = model_pred - noise
+            else:
+                raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+        print(f"forward_time = {time.time()-start_time}")
+        start_time = time.time()
+        with xp.Trace("backward"):
+            if self.args.snr_gamma is None:
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+            else:
+                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                # This is discussed in Section 4.2 of the same paper.
+                snr = compute_snr(self.noise_scheduler, timesteps)
+                mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                    dim=1
+                )[0]
+                if self.noise_scheduler.config.prediction_type == "epsilon":
+                    mse_loss_weights = mse_loss_weights / snr
+                elif self.noise_scheduler.config.prediction_type == "v_prediction":
+                    mse_loss_weights = mse_loss_weights / (snr + 1)
+
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                loss = loss.mean()
+            loss.backward()
         print(f"backward time = {time.time()-start_time}")
         start_time = time.time()
-        # with xp.Trace("optimizer_step"):
-        if tracer:
-            tracer.start()
-        self.run_optimizer()
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="optimizer.json")
+        with xp.Trace("optimizer_step"):
+            self.run_optimizer()
         print(f"optimizer step = {time.time()-start_time}")
-        return loss
+        return model_pred
 
 
 def parse_args():
@@ -567,7 +547,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
 
+
     prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(dtype=dtype)
+    print(prompt_embeds.shape)
+    p3d = (0,0, 0, 128-77)
+    prompt_embeds = F.pad(prompt_embeds, p3d, "constant", 0)
+    print(prompt_embeds.shape)
     pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(dtype=dtype)
     return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
 
@@ -580,7 +565,8 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input}
+    xm.mark_step()
+    return {"model_input": model_input.cpu()}
 
 
 def load_dataset(args):
@@ -770,16 +756,20 @@ def preprocess_train(examples):
     )
     compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
     from datasets.fingerprint import Hasher
-
+    # import pdb; pdb.set_trace()
+    old_batch_size = args.train_batch_size
+    args.train_batch_size=21
     new_fingerprint = Hasher.hash(args)
+    args.train_batch_size=64
     new_fingerprint_for_vae = Hasher.hash((args.pretrained_model_name_or_path, args))
+    args.train_batch_size=old_batch_size
     train_dataset_with_embeddings = train_dataset.map(
-        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        compute_embeddings_fn, batched=True, batch_size=50, new_fingerprint=new_fingerprint
     )
     train_dataset_with_vae = train_dataset.map(
         compute_vae_encodings_fn,
         batched=True,
-        batch_size=args.train_batch_size,
+        batch_size=50,
         new_fingerprint=new_fingerprint_for_vae,
     )
     precomputed_dataset = concatenate_datasets(
@@ -794,14 +784,6 @@ def collate_fn(examples):
     crop_top_lefts = torch.stack([torch.tensor(example["crop_top_lefts"]) for example in examples])
     prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
     pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
-    # print("model_input.shape: ", model_input.shape)
-    # print("model_input.dtype: ", model_input.dtype)
-    # print("prompt_embeds.shape: ", prompt_embeds.shape)
-    # print("prompt_embeds.dtype: ", prompt_embeds.dtype)
-    # print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape)
-    # print("pooled_prompt_embeds.dtype: ", pooled_prompt_embeds.dtype)
-    # print("original_sizes.shape: ", original_sizes.shape)
-    # print("crop_top_lefts.shape: ", crop_top_lefts.shape)
     return {
         "model_input": model_input,
         "prompt_embeds": prompt_embeds,
@@ -846,7 +828,7 @@ def collate_fn(examples):
     )
     print(f" Total optimization steps = {args.max_train_steps}")
 
-    unet = add_checkpoints(unet)
+    # unet = add_checkpoints(unet)
 
     trainer = TrainSD(
         weight_dtype=weight_dtype,