Skip to content

Commit f8a48d6

Browse files
enable TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
1 parent fad8020 commit f8a48d6

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

examples/research_projects/pytorch_xla/training/text_to_image/README_sdxl.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
 export XLA_DISABLE_FUNCTIONALIZATION=1
+export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
 from diffusers.utils import is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 
-# torch._dynamo.config.force_parameter_static_shapes = False
-
 if is_wandb_available():
     pass
 

@@ -635,7 +633,6 @@ def main(args):
     text_encoder_2 = text_encoder_2.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
-    #unet = torch.compile(unet, backend='openxla', dynamic=True)
     optimizer = setup_optimizer(unet, args)
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -723,14 +720,6 @@ def collate_fn(examples):
         crop_top_lefts = torch.stack([torch.tensor(example["crop_top_lefts"]) for example in examples])
         prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
-        # print("model_input.shape: ", model_input.shape)
-        # print("model_input.dtype: ", model_input.dtype)
-        # print("prompt_embeds.shape: ", prompt_embeds.shape)
-        # print("prompt_embeds.dtype: ", prompt_embeds.dtype)
-        # print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape)
-        # print("pooled_prompt_embeds.dtype: ", pooled_prompt_embeds.dtype)
-        # print("original_sizes.shape: ", original_sizes.shape)
-        # print("crop_top_lefts.shape: ", crop_top_lefts.shape)
         return {
             "model_input": model_input,
             "prompt_embeds": prompt_embeds,

0 commit comments

Comments
 (0)