Conversation

@scxue (Collaborator) commented Apr 7, 2025

Initial implementation of the SANA-Sprint training script, adapted for Diffusers.
It still needs further refinement and optimization. @lawrence-cj @sayakpaul

@sayakpaul (Contributor)

Will review in a bit.

@sayakpaul (Contributor) left a comment

Looking really promising. I left some comments; LMK if they make sense.

Additionally, if we could wrap the loss computations for the different phases into separate functions, I think the script would be easier to read, e.g. the sketch below. LMK what you think.
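
Not prescriptive, but roughly the shape I have in mind; everything below (names, signatures, loss terms) is a placeholder sketch, not code from this PR:

```python
import torch.nn.functional as F


def compute_generator_loss(transformer, disc, noisy_latents, timesteps, prompt_embeds):
    # "G" phase: the student predicts and the discriminator scores the
    # prediction; the script's consistency term would be added alongside
    # this adversarial term.
    pred = transformer(noisy_latents, timesteps, prompt_embeds)
    adv_logits = disc(pred, timesteps, prompt_embeds)
    return F.softplus(-adv_logits).mean()  # non-saturating generator loss


def compute_discriminator_loss(disc, real_latents, fake_latents, timesteps, prompt_embeds):
    # "D" phase: classify real vs. fake latents; detach the fakes so no
    # gradient flows back into the generator.
    real_logits = disc(real_latents, timesteps, prompt_embeds)
    fake_logits = disc(fake_latents.detach(), timesteps, prompt_embeds)
    return (F.softplus(-real_logits) + F.softplus(fake_logits)).mean()
```

The step body would then just dispatch on the current phase, which also makes the G/D alternation easier to follow.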

@@ -0,0 +1,1823 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.

Feel free to add the SANA-Sprint team here too :)
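
e.g. something like (exact attribution wording up to you):

```python
# Copyright 2025 The HuggingFace Inc. team and the SANA-Sprint authors. All rights reserved.
```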

if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False

complex_human_instruction = [

Suggested change:
-complex_human_instruction = [
+COMPLEX_HUMAN_INSTRUCTION = [

return False


class Text2ImageDataset:

Do we have an example dataset with which it would work?

)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples

Could we use num_train_examples here, no?

disc.eval()
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
    with torch.no_grad():

We can then remove this torch.no_grad() context manager.

images = None
del pipeline

# Save the lora layers

We are not doing LoRA here, so this can be safely omitted.

cfg_y = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

cfg_pretrain_pred = pretrained_model(

As another optimization, we could keep pretrained_model on CPU once this computation is done and move it back to the GPU only when needed.
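
Roughly like this (a sketch; the helper names are mine, not from the script):

```python
import torch


def offload_teacher(model: torch.nn.Module) -> None:
    # Park the frozen teacher on CPU once its CFG prediction is done
    # and release the cached VRAM it was holding.
    model.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def onload_teacher(model: torch.nn.Module, device: torch.device) -> None:
    # Bring the teacher back right before its next forward pass.
    model.to(device)
```

i.e., offload_teacher(pretrained_model) right after cfg_pretrain_pred is computed, and onload_teacher(pretrained_model, accelerator.device) when it is needed again. Worth profiling first, though: the per-step transfers could cost more than the memory savings are worth.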

phase = "G"

optimizer_D.step()
optimizer_D.zero_grad(set_to_none=True)

I think set_to_none already defaults to True (it has since PyTorch 2.0), so the explicit argument can be dropped.
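
i.e., on PyTorch >= 2.0 this is equivalent:

```python
optimizer_D.step()
optimizer_D.zero_grad()  # set_to_none=True has been the default since PyTorch 2.0
```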

lr_scheduler.step()
optimizer_G.zero_grad(set_to_none=True)

elif phase == "D":

So this alternates between the two phases within the same training step, right? If so, I would add a comment saying so.

Also, should we let users control the interval at which the discriminator is updated? Or is that not needed? A sketch of one way to expose it is below.
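
If we do want that, a minimal sketch (the --disc_update_interval flag is hypothetical, not in this PR):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disc_update_interval",
    type=int,
    default=1,
    help="Run the discriminator ('D') phase every N training steps.",
)
args = parser.parse_args([])  # empty argv, just for this demo

for global_step in range(4):
    # The generator ("G") phase runs every step; the discriminator ("D")
    # phase only on steps that are a multiple of the interval.
    run_d_phase = (global_step % args.disc_update_interval) == 0
    print(global_step, "G+D" if run_d_phase else "G")
```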

@scxue (Collaborator, Author) commented Apr 9, 2025

Thanks for your thorough review and helpful suggestions! I'll carefully go through them and incorporate the changes when I'm back. Really appreciate it!

@sayakpaul (Contributor)

Please don't hesitate to ping me for running tests, etc.

@lawrence-cj (Collaborator)

Adding here:

huggingface/diffusers#11514

@lawrence-cj (Collaborator)

Let's duplicate the files from diffusers here. @scxue
Refer to: huggingface/diffusers#11514
