Add RAE Diffusion Transformer inference/preliminary training pipelines#13231

Open
plugyawn wants to merge 13 commits into huggingface:main from plugyawn:rae-dit-training

Conversation

@plugyawn plugyawn commented Mar 9, 2026

What does this PR do?

This PR adds support for Diffusion Transformers with Representation Autoencoders in Diffusers.

It implements the Stage-2 side of the RAE setup:

  • RAEDiT2DModel
  • RAEDiTPipeline
  • checkpoint conversion for published upstream Stage-2 checkpoints
  • API docs
  • a small examples/research_projects/rae_dit/ training scaffold

This addresses #13225.

Reference implementation: bytetriper's repository

Validation

Inference parity with the official implementation is high. For matched class label / initial latent noise / schedule, I measured:

  • max_abs_error=0.00001717
  • mean_abs_error=0.00000122

Qualitative parity artifacts used during validation:

  • same published Stage-2 checkpoint
  • same class label
  • same initial latent noise
  • same 25-step shifted Euler schedule
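For reference, a minimal sketch of how parity numbers like the above are typically computed. The tensors here are synthetic stand-ins, not the actual RAE outputs; in the real comparison they would be the final latents or decoded images from the two matched runs:

```python
import torch

# Synthetic stand-ins for the two implementations' outputs.
torch.manual_seed(0)
ours = torch.randn(1, 3, 256, 256)
reference = ours + 1e-5 * torch.randn_like(ours)

# Elementwise absolute difference between the matched runs.
diff = (ours - reference).abs()
print(f"max_abs_error={diff.max().item():.8f}")
print(f"mean_abs_error={diff.mean().item():.8f}")
```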

Inference is also slightly faster in the current Diffusers port on a 40GB A100:

| Precision | CFG | Steps | Diffusers sec/img | Upstream sec/img | Diffusers img/s | Delta |
| --- | --- | --- | --- | --- | --- | --- |
| bf16 | 1.0 | 25 | 0.817 | 0.913 | 1.225 | +11.8% |
| bf16 | 4.0 | 25 | 0.852 | 0.931 | 1.174 | +9.3% |
| bf16 | 1.0 | 50 | 1.568 | 1.761 | 0.638 | +12.3% |
| bf16 | 4.0 | 50 | 1.649 | 1.853 | 0.606 | +12.4% |

Notes

  • This PR intentionally does not add upstream autoguidance / guidance-model support.
  • The training script is a research-project scaffold under examples/research_projects, not a claim of full upstream training parity.
  • AutoencoderRAE.from_pretrained() is used for the Stage-1 component so the packaged RAEDiTPipeline.from_pretrained(...) path works with published RAE checkpoints.


@plugyawn plugyawn changed the title Add Stage-2 RAE DiT support with pipeline, conversion, and training tooling RAE DiT inference, checkpoint conversion, and preliminary training tooling Mar 9, 2026
@plugyawn plugyawn changed the title RAE DiT inference, checkpoint conversion, and preliminary training tooling Add RAE Diffusion Transformer inference/preliminary training pipelines Mar 9, 2026
@plugyawn plugyawn marked this pull request as draft March 9, 2026 05:46
@plugyawn plugyawn marked this pull request as ready for review March 9, 2026 05:51

plugyawn commented Mar 9, 2026

@kashif @sayakpaul would be great if you could review. Please note the no_init_weights() fix (details in the PR body); if you prefer, that could be a separate PR, but considering diffusers is supposed to be an extension to torch, I guess it makes sense?

@sayakpaul
Member

Thanks for the PR. To keep the scope manageable, could we break it down into separate PRs?

For example,

there is also a change to no_init_weights(). Specifically, it makes Diffusers' skip-weight-init behave more like normal PyTorch. Today, when no_init_weights() is active, the torch.nn.init.* stubs stop returning the tensor they were called on (plain PyTorch does return it). Most models never notice, but the RAE-DiT implementation relies on the return value during construction, which can make otherwise valid checkpoints fail to load through the standard from_pretrained() path.

could be a separate PR.

@sayakpaul sayakpaul left a comment


Thanks!

I left some initial comments, let me know if they make sense.

Comment on lines +13 to +16
- `examples/dreambooth/train_dreambooth_flux.py`
for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks.
- `examples/flux-control/train_control_flux.py`
for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers.
Member

Doesn't belong here.

Comment on lines +218 to +221
# Preserve the `torch.nn.init.*` return contract so third-party model
# constructors that chain on the returned tensor still work under
# `no_init_weights()`.
return args[0] if len(args) > 0 else None
Member

Can you provide an example?

super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

@unittest.skip(
"RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative."
Member

I don't think this is the case? We can always skip layerwise casting for certain layer or layer groups here:

_skip_layerwise_casting_patterns = None

model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02)


class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase):
Member

Test should use the newly added model tester mixins. You can find an example in #13046

Comment on lines +48 to +49
if shift is None:
shift = torch.zeros_like(scale)
Member

This is a small function; it would be fine to inline it at the call sites?

We also probably don't need _repeat_to_length().
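For illustration, the inlined form could look like the sketch below. This assumes the helper applies an adaLN-style `x * (1 + scale) + shift` modulation; the shapes and that modulation form are assumptions, only the shift-is-None guard comes from the quoted snippet:

```python
import torch

# Hypothetical shapes; `shift` may be None, as in the original guard.
hidden_states = torch.randn(2, 16, 8)
scale = torch.randn(2, 1, 8)
shift = None

# Inlined helper: treat a missing shift as zeros, then modulate.
shift = torch.zeros_like(scale) if shift is None else shift
hidden_states = hidden_states * (1 + scale) + shift
print(hidden_states.shape)
```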

Comment on lines +466 to +470
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt"
)
self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))
Member

Can we use how #13046 initialized the position embeddings?

Author

Yeah, that makes sense, will do that.

)
return hidden_states

def _run_block(
Member

We don't need this. Let's instead follow this pattern:

for index_block, block in enumerate(self.transformer_blocks):
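A hedged sketch of the suggested pattern, iterating `self.transformer_blocks` directly in `forward()` instead of routing through a `_run_block` helper. The `Block` and `temb` details below are toy stand-ins, not the RAE-DiT internals, and the checkpointing branch mimics the common diffusers structure rather than its exact helper:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class Block(nn.Module):
    """Toy stand-in for a transformer block taking (hidden_states, temb)."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        return self.net(hidden_states) + temb


class ToyTransformer(nn.Module):
    def __init__(self, dim=8, depth=3):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb):
        # Iterate the blocks inline, as suggested in the review.
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # diffusers models dispatch through a checkpointing helper here;
                # plain torch.utils.checkpoint stands in for it in this sketch.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block, hidden_states, temb, use_reentrant=False
                )
            else:
                hidden_states = block(hidden_states, temb)
        return hidden_states


model = ToyTransformer()
out = model(torch.randn(2, 8), torch.zeros(2, 8))
print(out.shape)
```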


return class_labels

def _prepare_latents(
Member

It should be called prepare_latents() similar to other pipelines.

Comment on lines +247 to +252
if output_type == "pt":
output = images
else:
output = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
output = self.numpy_to_pil(output)
Member

We should use an image processor instead here. See:

image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (output,)

return ImagePipelineOutput(images=output)
Member

Let's give this pipeline a separate output class: RAEDiTPipelineOutput.
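A sketch of what that might look like. Real diffusers pipeline outputs subclass `BaseOutput`; a plain dataclass is used here to stay self-contained, and the field shape/name mirrors `ImagePipelineOutput` as an assumption:

```python
from dataclasses import dataclass
from typing import List, Union

import numpy as np


@dataclass
class RAEDiTPipelineOutput:
    """Sketch of a dedicated output class (the real one would subclass
    diffusers.utils.BaseOutput). `images` is a list of PIL images or an
    NHWC numpy array, depending on `output_type`."""

    # Whole annotation kept as a string so PIL need not be imported here.
    images: "Union[List[PIL.Image.Image], np.ndarray]"


out = RAEDiTPipelineOutput(images=np.zeros((1, 32, 32, 3), dtype=np.float32))
print(out.images.shape)
```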

@sayakpaul sayakpaul requested review from dg845 and kashif March 9, 2026 11:33

plugyawn commented Mar 10, 2026

@sayakpaul, from what I understand the RAE checkpoint -> DiT checkpoint -> generation pipeline necessarily requires the no_init_weights() change (otherwise the semantics become a bit muddled, imo).

Would it make more sense to open a PR for handling no_init_weights() behavior before this one?

@sayakpaul
Member

Could you explain why that's needed? I am still not sure about that, actually. I'd prefer specific examples that fail without the init change.


plugyawn commented Mar 10, 2026

Not sure how to link files, but it seems to be related to changes introduced in #13046.

A specific example,

  • AutoencoderRAE constructs Dinov2WithRegistersModel(config) in _build_encoder:84.
  • ModelMixin.from_pretrained() always does that construction under no_init_weights() first, even before low_cpu_mem_usage matters; see modeling_utils.py:1270.
  • In current transformers, DINOv2-with-registers has init code like this in modeling_dinov2_with_registers.py:464:

  module.weight.data = nn.init.trunc_normal_(
      module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  ).to(module.weight.dtype)

Under today’s no_init_weights(), nn.init.trunc_normal_ is replaced with a stub that just passes (and so returns None), so the assignment becomes None.to(...) and fails with AttributeError: 'NoneType' object has no attribute 'to'.

Codex has a better summary, I think:

failing example: AutoencoderRAE builds Dinov2WithRegistersModel(config) in its encoder path, and ModelMixin.from_pretrained() always instantiates models under no_init_weights() first. In current transformers, DINOv2’s init_weights() assigns the return value of nn.init.trunc_normal_(...) and then calls .to(...) on it. With the current no_init_weights() stub, that return value becomes None, so construction fails with AttributeError: 'NoneType' object has no attribute 'to'. The proposed change keeps skip-init behavior intact, but restores the normal PyTorch return contract so these constructors remain compatible.
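To make that concrete, here is a minimal, self-contained repro sketch of the failure mode. The context manager below is a toy stand-in for diffusers' no_init_weights(), not the real implementation; only the chained `.to(...)` init pattern is taken from the DINOv2 code quoted above:

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def no_init_weights_stub(preserve_return: bool):
    """Toy no_init_weights(): swap trunc_normal_ for a no-op stub.

    preserve_return=False mimics the current stub (returns None);
    preserve_return=True mimics the proposed fix (returns the tensor).
    """
    original = torch.nn.init.trunc_normal_

    def stub(tensor, *args, **kwargs):
        return tensor if preserve_return else None

    torch.nn.init.trunc_normal_ = stub
    try:
        yield
    finally:
        torch.nn.init.trunc_normal_ = original


def dinov2_style_init(module: nn.Linear):
    # The DINOv2-with-registers pattern: chain .to() on the init return value.
    module.weight.data = nn.init.trunc_normal_(
        module.weight.data.to(torch.float32), mean=0.0, std=0.02
    ).to(module.weight.dtype)


layer = nn.Linear(4, 4)

with no_init_weights_stub(preserve_return=True):
    dinov2_style_init(layer)  # fine: the stub returns the tensor

try:
    with no_init_weights_stub(preserve_return=False):
        dinov2_style_init(layer)
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'to'
```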

Re: #13046, note test_models_autoencoder_rae.py:45, where the unit tests seem to be a little off, imo. Not sure the tests are aligned.

# ---------------------------------------------------------------------------
# Tiny test encoder for fast unit tests (no transformers dependency)
# ---------------------------------------------------------------------------
# (Excerpt: assumes `torch`, `torch.nn.functional as F`, and the RAE module's
# private `_ENCODER_FORWARD_FNS` / `_build_encoder` / `_rae_module` are in scope.)


class _TinyTestEncoderModule(torch.nn.Module):
    """Minimal encoder that mimics the patch-token interface without any HF model."""

    def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        self.hidden_size = hidden_size

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size)
        tokens = pooled.flatten(2).transpose(1, 2).contiguous()
        return tokens.repeat(1, 1, self.hidden_size)


def _tiny_test_encoder_forward(model, images):
    return model(images)


def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)


# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
_original_build_encoder = _build_encoder


def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    if encoder_type == "tiny_test":
        return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
    return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)


_rae_module._build_encoder = _patched_build_encoder

I'm new to diffusers idioms, but I was confused why this appeared to be a problem only now, and asked GPT:

no_init_weights() only becomes a problem when all of these are true at once:

  • a diffusers ModelMixin.from_pretrained() call is constructing the model
  • that model’s init() instantiates another model internally
  • that internal model uses torch.nn.init.* and also relies on its return value

RAE is unusual because it does exactly that. Inside autoencoder_rae.py, the AutoencoderRAE constructor directly builds a transformers vision backbone:

  • Dinov2WithRegistersModel:98
  • SiglipVisionModel:111
  • ViTMAEModel:124

That is not how most other diffusers integrations are structured. Most of the repo does one of these instead:

  • native diffusers models in src/diffusers/models, whose init code only relies on side effects
  • pipelines that accept transformers models as separate top-level components, rather than constructing them inside a ModelMixin

So other work usually does not run a transformers constructor inside diffusers’ patched no_init_weights() context.

@sayakpaul
Member

Not sure how to link files

Yes, we can link files and I think it's better this way. For example, it's much better to refer to specific lines like

def get_parameter_device(parameter: torch.nn.Module) -> torch.device:

instead of plain text.

Overall, I think that the explanation you provided in the above comment isn't that helpful. We need some specific (preferably very minimal) code snippet with and without that change to better understand what's happening and why.

For PRs like this, it's an expectation that contributors will take some time to understand the library code.
