Add RAE Diffusion Transformer inference/preliminary training pipelines #13231
plugyawn wants to merge 13 commits into huggingface:main from
Conversation
@kashif @sayakpaul would be great if you could review. Please note the no_init_weights() fix (details in the PR body); if you prefer, that could be a separate PR, but considering
Thanks for the PR. To keep the scope manageable, could we break it down into separate PRs? For example,
could be a separate PR.
sayakpaul left a comment
Thanks!
I left some initial comments, let me know if they make sense.
- `examples/dreambooth/train_dreambooth_flux.py` for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks.
- `examples/flux-control/train_control_flux.py` for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers.
# Preserve the `torch.nn.init.*` return contract so third-party model
# constructors that chain on the returned tensor still work under
# `no_init_weights()`.
return args[0] if len(args) > 0 else None
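For context, a minimal sketch of why the return contract matters (the stub names here are illustrative, not the actual diffusers internals): any code that chains on the `nn.init.*` return value receives `None` from a stub that drops it.

```python
import torch


def stub_dropping_return(tensor, *args, **kwargs):
    # A no-op init stub that forgets the nn.init.* contract.
    pass  # implicitly returns None


def stub_keeping_return(tensor, *args, **kwargs):
    # A no-op stub that preserves the contract: return the input tensor.
    return tensor


weight = torch.empty(4, 4)

# Constructor patterns that chain on the returned tensor, e.g.
#   param = nn.Parameter(nn.init.trunc_normal_(torch.empty(...), std=0.02))
# break under the first stub and keep working under the second.
assert stub_dropping_return(weight) is None
assert stub_keeping_return(weight) is weight
```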
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

@unittest.skip(
    "RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative."
)
I don't think this is the case? We can always skip layerwise casting for certain layer or layer groups here:
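A sketch of the idea (the `_skip_layerwise_casting_patterns` attribute name and the matching rule are my assumptions about how such an opt-out typically looks, not verified diffusers API):

```python
class RAEDiT2DModel:
    # Hypothetical: modules whose qualified name matches one of these
    # patterns keep their original dtype instead of being down-cast.
    _skip_layerwise_casting_patterns = ["final_layer", "norm"]


def should_skip(module_name, patterns):
    # Simple substring match against the skip list.
    return any(p in module_name for p in patterns)


# The zero-initialized output head is excluded from casting...
assert should_skip("final_layer.linear", RAEDiT2DModel._skip_layerwise_casting_patterns)
# ...while ordinary transformer blocks are still cast.
assert not should_skip("transformer_blocks.0.attn1", RAEDiT2DModel._skip_layerwise_casting_patterns)
```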
model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02)

class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase):
Tests should use the newly added model tester mixins. You can find an example in #13046.
if shift is None:
    shift = torch.zeros_like(scale)
This is a small function; it could just be inlined at the call sites. We also probably don't need `_repeat_to_length()`.
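For reference, the adaLN-style modulation such a helper typically wraps is a one-liner, so inlining costs little (shapes below are illustrative):

```python
import torch

hidden_states = torch.randn(2, 16, 8)   # (batch, tokens, dim)
scale = torch.randn(2, 1, 8)            # per-sample modulation
shift = torch.zeros_like(scale)         # the `shift is None` default

# Inlined at the call site instead of a dedicated helper:
hidden_states = hidden_states * (1 + scale) + shift
assert hidden_states.shape == (2, 16, 8)
```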
if self.use_pos_embed:
    pos_embed = get_2d_sincos_pos_embed(
        self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt"
    )
    self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))
Can we initialize the position embeddings the way #13046 does?
Yeah, that makes sense, will do that.
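For background, a self-contained sketch of the fixed 2D sin-cos table being discussed (this is not diffusers' actual `get_2d_sincos_pos_embed`; axis ordering and frequency conventions may differ from the library's helper):

```python
import torch


def sincos_pos_embed_1d(embed_dim, positions):
    # Standard 1D sin-cos embedding: half sine, half cosine channels.
    omega = torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim // 2)
    omega = 1.0 / 10000 ** omega
    out = positions.double().flatten()[:, None] * omega[None, :]
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)


def sincos_pos_embed_2d(embed_dim, grid_size):
    # Concatenate per-axis embeddings for each (h, w) grid position.
    ys, xs = torch.meshgrid(
        torch.arange(grid_size), torch.arange(grid_size), indexing="ij"
    )
    emb_h = sincos_pos_embed_1d(embed_dim // 2, ys)
    emb_w = sincos_pos_embed_1d(embed_dim // 2, xs)
    return torch.cat([emb_h, emb_w], dim=1).float()


# Computed once at construction time (e.g. copied into a buffer),
# rather than re-derived from the parameter's own shape at runtime.
pos = sincos_pos_embed_2d(64, 4)
assert pos.shape == (16, 64)
```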
)
return hidden_states

def _run_block(
We don't need this. Let's instead follow this pattern:
return class_labels

def _prepare_latents(
It should be called `prepare_latents()`, similar to other pipelines.
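A minimal sketch of the convention being suggested (signature and body are illustrative, not this PR's actual implementation):

```python
import torch


def prepare_latents(batch_size, num_channels, height, width, dtype, device, generator=None):
    # Sample the initial Gaussian latents the denoising loop starts from,
    # mirroring the prepare_latents() naming used by other pipelines.
    shape = (batch_size, num_channels, height, width)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)


latents = prepare_latents(
    2, 4, 16, 16, torch.float32, "cpu", generator=torch.Generator().manual_seed(0)
)
assert latents.shape == (2, 4, 16, 16)
```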
if output_type == "pt":
    output = images
else:
    output = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        output = self.numpy_to_pil(output)
We should use an image processor instead here. See:
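In diffusers this branching is normally delegated to an image processor (e.g. `VaeImageProcessor.postprocess()`); the stand-in below only sketches the tensor-to-array conversion that delegation would replace:

```python
import torch


def postprocess(images: torch.Tensor, output_type: str = "np"):
    # Minimal stand-in for an image processor's postprocess():
    # return torch tensors as-is, otherwise move channels last.
    if output_type == "pt":
        return images
    arr = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return arr  # for "pil", a numpy-to-PIL conversion would follow here


out = postprocess(torch.rand(1, 3, 8, 8), output_type="np")
assert out.shape == (1, 8, 8, 3)
```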
if not return_dict:
    return (output,)

return ImagePipelineOutput(images=output)
Let's give this pipeline a separate output class: RAEDiTPipelineOutput.
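A sketch of what that output class could look like (in diffusers it would subclass `diffusers.utils.BaseOutput` for tuple-style access; a plain dataclass is used here to keep the example self-contained):

```python
from dataclasses import dataclass
from typing import List, Union

import numpy as np


@dataclass
class RAEDiTPipelineOutput:
    # In the library this would be `class RAEDiTPipelineOutput(BaseOutput)`;
    # `images` holds PIL images or an (N, H, W, C) array.
    images: Union[List, np.ndarray]


out = RAEDiTPipelineOutput(images=np.zeros((1, 8, 8, 3)))
assert out.images.shape == (1, 8, 8, 3)
```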
@sayakpaul, from what I understand the RAE checkpoint -> DiT checkpoint -> generation pipeline necessarily requires the no_init_weights() change (otherwise the semantics become a bit muddled, imo). Would it make more sense to open a PR for handling
Could you explain why that's needed? I am still not sure about it, actually. Please provide specific examples that fail without the init change.
Not sure how to link files, but it seems to be related to changes introduced in #13046. A specific example: under today's no_init_weights(), nn.init.trunc_normal_ is replaced with a stub that just passes. Codex has a better summary, I think:

Re: #13046, I'm new to diffusers idiomatics, but I was confused why this appeared to be a problem only now, and asked GPT:
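A minimal with/without repro of the failure mode described above (my own sketch, standing in for the elided summary; the context-manager name is invented for illustration):

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def no_init_weights_returning_none():
    # Emulate a no-op init context whose stub drops the nn.init.*
    # return value -- the behavior this PR's fix addresses.
    original = nn.init.trunc_normal_
    nn.init.trunc_normal_ = lambda tensor, *a, **k: None
    try:
        yield
    finally:
        nn.init.trunc_normal_ = original


# Without the context, trunc_normal_ returns the tensor it filled in.
t = torch.empty(2, 2)
assert nn.init.trunc_normal_(t, std=0.02) is t

# Inside the context, code chaining on the return value gets None.
with no_init_weights_returning_none():
    result = nn.init.trunc_normal_(torch.empty(2, 2), std=0.02)
assert result is None  # downstream code expecting a tensor then fails
```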
Yes, we can link files, and I think it's better this way. For example, it's much better to refer to specific lines like instead of plain text. Overall, I don't think the explanation you provided in the above comment is that helpful. We need a specific (preferably very minimal) code snippet with and without that change to better understand what's happening and why. For this kind of PR, the expectation is that contributors will take some time to understand the library code.
What does this PR do?
This PR adds support for Diffusion Transformers with Representation Autoencoders in Diffusers.
It implements the Stage-2 side of the RAE setup:
- RAEDiT2DModel
- RAEDiTPipeline
- examples/research_projects/rae_dit/ training scaffold

This addresses #13225.
Reference implementation: byteriper's repository
Validation
Inference parity with the official implementation is high. For matched class label / initial latent noise / schedule, I measured:
- max_abs_error = 0.00001717
- mean_abs_error = 0.00000122

Qualitative parity artifacts used during validation:
Inference is also slightly faster in the current Diffusers port on a 40GB A100:
Notes
- The training scaffold lives under examples/research_projects and is not a claim of full upstream training parity.
- AutoencoderRAE.from_pretrained() is used for the Stage-1 component so the packaged RAEDiTPipeline.from_pretrained(...) path works with published RAE checkpoints.

Before submitting
See the documentation guidelines, and here are tips on formatting docstrings.