
Conversation

@leffff (Contributor) commented Oct 13, 2025

What does this PR do?

This PR adds Kandinsky5T2VPipeline and Kandinsky5Transformer3DModel, as well as several layer classes needed for the Kandinsky 5.0 Lite T2V model.

@sayakpaul Please review

@sayakpaul requested review from DN6 and yiyixuxu, October 14, 2025
@sayakpaul (Member) commented:

Could you please update the PR with test code and some example outputs?

@leffff (Contributor, Author) commented Oct 14, 2025

Sure!

@leffff (Contributor, Author) commented Oct 14, 2025

Dear @sayakpaul @yiyixuxu @DN6
What should the test code and example outputs look like?

@leffff (Contributor, Author) commented Oct 14, 2025

import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", 
    torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

negative_prompt = [
    "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards",
]
prompt = [
    "A cat and a dog baking a cake together in a kitchen.",
]

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=121,
    num_inference_steps=50,
    guidance_scale=5.0,
    num_videos_per_prompt=1,
    generator=torch.Generator().manual_seed(42),
)
# 121 frames at 24 fps ≈ 5 s of video
export_to_video(output.frames[0], "output.mp4", fps=24)

[video: output.10.mp4]
prompt = [
    "A monkey riding a skateboard",
]

[video: output.10.mp4]

prompt = [
    "Several giant wooly mammoths threading through the meadow",
]

[video: output.10.mp4]

@sayakpaul (Member) commented:

Great, thanks for providing the examples! Does the model also do realistic generations? 👀

@linoytsaban @apolinario @asomoza in case you wanna test it?

@leffff (Contributor, Author) commented Oct 14, 2025

Yes of course!

A stylish woman struts confidently down a rain-drenched Tokyo street, where vibrant neon signs flicker and pulse with electric color. She wears a sleek black leather jacket over a flowing red dress, paired with polished black boots and a matching black purse. Her sunglasses reflect the glowing cityscape as she moves with a calm, assured demeanor, red lipstick adding a bold contrast to her look. The wet pavement mirrors the dazzling lights, doubling the intensity of the urban glow around her. Pedestrians bustle along the sidewalks, their silhouettes blending into the dynamic, cinematic atmosphere of the neon-lit metropolis.

[video: output.10.mp4]

A cinematic movie trailer unfolds with a 30-year-old space man traversing a vast salt desert beneath a brilliant blue sky. He wears a uniquely styled red wool knitted motorcycle helmet, adding an eccentric yet rugged charm to his spacefaring look. As he rides a retro-futuristic vehicle across the shimmering white terrain, the wind kicks up clouds of glittering salt, creating a surreal atmosphere. The scene is captured in a vivid, cinematic style, shot on 35mm film to enhance the nostalgic and dramatic grain. Explosions of color and dynamic camera movements highlight the space man's daring escape from a collapsing alien base in the distance.

[video: output.11.mp4]

@asomoza (Member) left a comment:

Thanks, looks cool! Left some suggestions for the unused imports.

@leffff (Contributor, Author) commented Oct 17, 2025

@yiyixuxu
Done! All your suggested fixes are applied. Ready to merge!

@asomoza (Member) commented Oct 17, 2025

@leffff just want to let you know that I've been testing the 10s model and I'm really impressed with it. I like it a lot, congrats to the team. Can't wait for you to release the I2V one.

[video: kangaroo.mp4]

@leffff (Contributor, Author) commented Oct 17, 2025

@asomoza Great! Gonna add them in the next iteration!

@yiyixuxu (Collaborator) left a comment:

Will merge once CI is green!

@yiyixuxu merged commit 23ebbb4 into huggingface:main on Oct 18, 2025 (28 of 31 checks passed).
@leffff (Contributor, Author) commented Oct 18, 2025

Hurrah!!!

@yiyixuxu (Collaborator) commented:

@leffff looking forward to the follow-up PR for the 10s model!
We are very happy to help too - let me know if you need anything :)

@leffff (Contributor, Author) commented Oct 20, 2025

Hi!
@yiyixuxu how can we make Kandinsky 5 appear here: https://huggingface.co/docs/diffusers/api/pipelines/overview?

@sayakpaul (Member) commented:

You need to add a page like: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/kandinsky.md
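
For reference, a minimal sketch of what such a page could contain, following the [[autodoc]] conventions used by the existing pipeline docs. The file path and description below are illustrative only, and the page also needs an entry in docs/source/en/_toctree.yml to appear in the navigation:

<!-- docs/source/en/api/pipelines/kandinsky5.md (illustrative path) -->
# Kandinsky 5.0

Kandinsky 5.0 Lite is a text-to-video model. <!-- placeholder description -->

## Kandinsky5T2VPipeline

[[autodoc]] Kandinsky5T2VPipeline
  - all
  - __call__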

@leffff (Contributor, Author) commented Oct 20, 2025

Great! Thanks!

@leffff (Contributor, Author) commented Oct 21, 2025

> Just commenting to note that we support all kinds of different attention backends now. So, as long as we implement the attention class in this way, swapping a backend from SDPA ("native" in our terminology) to "flex", for example, should be very easy.
>
> model.set_attention_backend("flex")

Yes, you are right. I tried doing

pipe.transformer.set_attention_backend("flex")

and it almost worked. You see, when I made separate processors, I did this:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention

# nablaT_v2 is the block-mask-building helper introduced in this PR.


class Kandinsky5NablaAttentionProcessor(nn.Module):
    """Custom attention processor for Nabla (sparse) attention."""

    @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
    def __call__(
        self,
        attn,
        query,
        key,
        value,
        sparse_params=None,
        **kwargs,
    ):
        if sparse_params is None:
            raise ValueError("sparse_params is required for Nabla attention")

        # flex_attention expects (batch, heads, seq, head_dim)
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        # build the sparse block mask from the query/key statistics,
        # the STA mask, and the threshold carried in sparse_params
        block_mask = nablaT_v2(
            query,
            key,
            sparse_params["sta_mask"],
            thr=sparse_params["P"],
        )
        out = (
            flex_attention(query, key, value, block_mask=block_mask)
            .transpose(1, 2)
            .contiguous()
        )
        # back to (batch, seq, heads, head_dim), then merge the head dims
        out = out.flatten(-2, -1)
        return out
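
For context, here is a minimal standalone sketch of the flex_attention + block-mask pattern the processor above builds on. The causal mask_mod and the create_block_mask call are illustrative stand-ins: nablaT_v2 (from this PR) derives its block mask from the query/key statistics instead.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # keep only positions on or below the diagonal
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 256, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)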

@sayakpaul (Member) commented:

> and it almost worked.

What do you mean? It didn't work as expected or are we good? 👀

@leffff (Contributor, Author) commented Oct 21, 2025

It worked as expected, but that's not the whole story: flex requires additional compilation. Please see #12520
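
For reference, the extra step looks roughly like the pattern from the PyTorch flex_attention docs: without wrapping it in torch.compile, flex_attention runs through a slow eager reference path.

import torch
from torch.nn.attention.flex_attention import flex_attention

# compile once at startup, then reuse for every call
flex_attention = torch.compile(flex_attention, dynamic=True)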

@sayakpaul (Member) commented:

I will reply to that PR.
