
Commit 9062c4f

move vae encoding to data pipeline.
1 parent 9ccc3a4 commit 9062c4f

2 files changed: +37, -33 lines changed
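In short, the commit moves VAE encoding out of the per-step training loop and into dataset preprocessing: latents are computed once with a batched map, the training step reads batch["model_input"] directly, and only the VAE's spatial scale factor is kept so the model itself can be freed before training. A minimal standalone sketch of that pattern (the checkpoint id and the preexisting train_dataset with a "pixel_values" column are assumptions for illustration, not part of this commit):

import functools

import torch
from diffusers import AutoencoderKL

# Assumed checkpoint; the actual script takes the model path from CLI args.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")


def compute_vae_encodings(batch, vae):
    # Encode preprocessed images to latents once, offline, instead of on every step.
    pixel_values = torch.stack([torch.tensor(img) for img in batch["pixel_values"]])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
    # Same shift/scale normalization the training step used to apply on the fly.
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
    return {"model_input": latents}


# `train_dataset` is assumed to already carry a preprocessed "pixel_values" column.
encode_fn = functools.partial(compute_vae_encodings, vae=vae)
train_dataset = train_dataset.map(encode_fn, batched=True, batch_size=8)

# The trainer only needs the spatial scale factor afterwards, so the VAE can be freed.
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
del vae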

examples/research_projects/pytorch_xla/training/text_to_image/README_flux.md

Lines changed: 0 additions & 1 deletion
@@ -95,7 +95,6 @@ are fixed.
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-export XLA_DISABLE_FUNCTIONALIZATION=1
 export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_flux.py

Lines changed: 37 additions & 32 deletions
@@ -121,8 +121,8 @@ def __init__(
         weight_dtype,
         device,
         noise_scheduler,
+        vae_scale_factor,
         transformer,
-        vae,
         optimizer,
         dataloader,
         args,
@@ -131,13 +131,13 @@ def __init__(
         self.device = device
         self.noise_scheduler = noise_scheduler
         self.transformer = transformer
-        self.vae = vae
         self.optimizer = optimizer
         self.args = args
         self.mesh = xs.get_global_mesh()
         self.dataloader = iter(dataloader)
         self.global_step = 0
         self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+        self.vae_scale_factor = vae_scale_factor

     def run_optimizer(self):
         self.optimizer.step()
@@ -198,13 +198,7 @@ def step_fn(
         prompt_embeds = batch["prompt_embeds"]
         pooled_prompt_embeds = batch["pooled_prompt_embeds"]
         text_ids = batch["text_ids"]
-
-        pixel_tensor_values = batch["pixel_tensor_values"]
-        model_input = self.vae.encode(pixel_tensor_values).latent_dist.sample()
-        model_input = (model_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-        model_input = model_input.to(dtype=self.weight_dtype)
-
-        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        model_input = batch["model_input"]

         latent_image_ids = FluxPipeline._prepare_latent_image_ids(
             model_input.shape[0],
@@ -264,9 +258,9 @@ def step_fn(
         # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
         model_pred = FluxPipeline._unpack_latents(
             model_pred,
-            height=model_input.shape[2] * vae_scale_factor,
-            width=model_input.shape[3] * vae_scale_factor,
-            vae_scale_factor=vae_scale_factor,
+            height=model_input.shape[2] * self.vae_scale_factor,
+            width=model_input.shape[3] * self.vae_scale_factor,
+            vae_scale_factor=self.vae_scale_factor,
         )

         # these weighting schemes use a uniform timestep sampling
@@ -626,6 +620,17 @@ def encode_prompt(

     return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids" : text_ids}

+def compute_vae_encodings(batch, vae, device, dtype):
+    images = batch.pop("pixel_values")
+    pixel_values = torch.stack(list(images))
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+    with torch.no_grad():
+        model_input = vae.encode(pixel_values).latent_dist.sample()
+    model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+    return {"model_input": model_input}
+
 def pixels_to_tensors(batch, device, dtype):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
@@ -729,20 +734,20 @@ def main(args):

     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

-    #unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    #transformer = apply_xla_patch_to_nn_linear(transformer, xs.xla_patched_nn_linear_forward)
     transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
     FlashAttention.DEFAULT_BLOCK_SIZES = {
-        "block_q": 1536,
-        "block_k_major": 1536,
-        "block_k": 1536,
-        "block_b": 1536,
-        "block_q_major_dkv": 1536,
-        "block_k_major_dkv": 1536,
-        "block_q_dkv": 1536,
-        "block_k_dkv": 1536,
-        "block_q_dq": 1536,
-        "block_k_dq": 1536,
-        "block_k_major_dq": 1536,
+        "block_q": 512,
+        "block_k_major": 512,
+        "block_k": 512,
+        "block_b": 512,
+        "block_q_major_dkv": 512,
+        "block_k_major_dkv": 512,
+        "block_q_dkv": 512,
+        "block_k_dkv": 512,
+        "block_q_dq": 512,
+        "block_k_dq": 768,
+        "block_k_major_dq": 512,
     }
     # For mixed precision training we cast all non-trainable weights (vae,
     # non-lora text_encoder and non-lora unet) to half-precision
@@ -812,8 +817,7 @@ def preprocess_train(examples):
         tokenizers=tokenizers,
         caption_column=caption_column,
     )
-    #compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
-    pixels_to_tensors_fn = functools.partial(pixels_to_tensors, device=device, dtype=weight_dtype)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae, device=device, dtype=weight_dtype)
     from datasets.fingerprint import Hasher

     new_fingerprint = Hasher.hash(args)
@@ -822,24 +826,25 @@ def preprocess_train(examples):
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
     )
     train_dataset_with_tensors = train_dataset.map(
-        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=256
+        compute_vae_encodings_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=8
     )
     precomputed_dataset = concatenate_datasets(
         [train_dataset_with_embeddings, train_dataset_with_tensors.remove_columns(["text", "image"])], axis=1
     )
     precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
-    del compute_embeddings_fn, text_encoder, text_encoder_2
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    del compute_embeddings_fn, text_encoder, text_encoder_2, vae
     del text_encoders, tokenizers
     def collate_fn(examples):
         prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples]).to(dtype=weight_dtype)
-        pixel_tensor_values = torch.stack([torch.tensor(example["pixel_tensor_values"]) for example in examples]).to(dtype=weight_dtype)
+        model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]).to(dtype=weight_dtype)
         return {
             "prompt_embeds": prompt_embeds,
             "pooled_prompt_embeds": pooled_prompt_embeds,
             "text_ids" : text_ids,
-            "pixel_tensor_values" : pixel_tensor_values
+            "model_input" : model_input,
         }

     g = torch.Generator()
@@ -860,7 +865,7 @@ def collate_fn(examples):
         input_sharding={
             "prompt_embeds" : xs.ShardingSpec(mesh, ("data", None, None), minibatch=True),
             "pooled_prompt_embeds" : xs.ShardingSpec(mesh, ("data", None,), minibatch=True),
-            "pixel_tensor_values" : xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "model_input" : xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
             "text_ids" : xs.ShardingSpec(mesh, ("data", None, None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
@@ -881,8 +886,8 @@ def collate_fn(examples):
         weight_dtype=weight_dtype,
         device=device,
         noise_scheduler=noise_scheduler,
+        vae_scale_factor=vae_scale_factor,
         transformer=transformer,
-        vae=vae,
         optimizer=optimizer,
         dataloader=train_dataloader,
         args=args,
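Note that the VAE is deleted right after the map step, so the trainer only receives vae_scale_factor, which it passes to FluxPipeline._unpack_latents to recover pixel-space height and width from the packed latents. As a rough worked example (the block_out_channels values are an assumed Flux VAE config, not read from this diff):

# Assumed config for illustration: four down blocks give a scale factor of 8,
# so a 1024x1024 image corresponds to a 128x128 latent grid and
# _unpack_latents multiplies the latent height/width back up by 8.
block_out_channels = (128, 256, 512, 512)  # assumption, not from the diff
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8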
