
Conversation

@linoytsaban
Collaborator

@linoytsaban linoytsaban commented Nov 29, 2024

Add the following to the flux redux prior pipeline:

  • prompt inputs
  • image interpolation


inference example:

# assumes `pipe_prior_redux` and `pipe` are already loaded (see the full setup in the review comment below)
from PIL import Image
import torch

image = Image.open("Self-portrait-oil-canvas-Thorn-Necklace-Hummingbird-Frida.jpg").convert("RGB")
image2 = Image.open("Mona_Lisa.jpg").convert("RGB")

pipe_prior_output = pipe_prior_redux(
    [image, image2],
    prompt=["self portrait by frida kahlo", "mona lisa"],
    prompt_embeds_scale=[0.9, 0.75],
    pooled_prompt_embeds_scale=[0.6, 1.25],
)
pipe(
    guidance_scale=2.5,
    height=1024,
    width=1024,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images[0]
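The per-image scales weight each image's Redux embeddings before they are combined into one conditioning. A rough sketch of the interpolation idea on plain tensors (illustrative only, not the pipeline's actual code):

```python
import torch

def combine_embeds(embeds_list, scales):
    # Scale each image's embedding, then average: this is roughly how
    # weighted interpolation between several Redux input images behaves.
    scaled = [e * s for e, s in zip(embeds_list, scales)]
    return torch.stack(scaled).sum(dim=0) / len(scaled)

e1 = torch.ones(1, 4)        # stand-in for image 1's embeddings
e2 = torch.ones(1, 4) * 2.0  # stand-in for image 2's embeddings
combined = combine_embeds([e1, e2], [0.9, 0.75])
```

Raising one image's scale relative to the other pulls the combined conditioning toward that image.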


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

yiyixuxu commented Dec 1, 2024

let us know when it is ready for a review!
cc @asomoza

@linoytsaban linoytsaban marked this pull request as ready for review December 2, 2024 11:02
@linoytsaban
Collaborator Author

@yiyixuxu @asomoza I think it's ready for review :)

Contributor

@hlky hlky left a comment


Works well, thanks! 🤗

Code
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image

device = "cuda"
dtype = torch.bfloat16
repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
repo_base = "black-forest-labs/FLUX.1-dev"
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
  repo_redux, torch_dtype=dtype
).to(device)
pipe = FluxPipeline.from_pretrained(
  repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=dtype
).to(device)

image = load_image(
  "https://www.arthistoryproject.com/site/assets/files/19982/frida-kahlo-self-portrait-with-thorn-necklace-and-hummingbird-1940-trivium-art-history.jpg"
)
image2 = load_image(
  "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6a/Mona_Lisa.jpg/1354px-Mona_Lisa.jpg"
)

pipe_prior_output = pipe_prior_redux(
  [image, image2],
  prompt=["self portrait by frida khalo", "mona lisa"],
  prompt_embeds_scale=[0.9, 0.75],
  pooled_prompt_embeds_scale=[0.6, 1.25],
)
image_out = pipe(
  guidance_scale=2.5,
  height=1024,
  width=1024,
  num_inference_steps=50,
  max_sequence_length=512,
  generator=torch.Generator("cpu").manual_seed(0),
  **pipe_prior_output,
).images[0]
image_out.save("flux-redux.png")
Output

(generated image, saved as flux-redux.png)

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks! this is really cool - would love to have it in the docs somewhere too

prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
prompt=prompt,
Collaborator


let's throw a warning here: if prompt inputs are passed but there is no text_encoder/tokenizer, the text inputs will be ignored

this is a bit different from our regular pipelines: normally, if you pass a prompt and do not have a text_encoder, you get an error from encode_prompt; here we just use zero prompt embeds instead, so let's make an explicit warning about that
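A minimal sketch of the suggested warning path (hypothetical helper name and message, not the actual diffusers code):

```python
import logging

logger = logging.getLogger("flux_redux_sketch")

def should_ignore_prompt(prompt, text_encoder, tokenizer):
    # If a prompt is passed but no text encoder/tokenizer is loaded,
    # warn and signal that zero prompt embeddings should be used
    # instead of raising an error, as described in the review comment.
    if prompt is not None and (text_encoder is None or tokenizer is None):
        logger.warning(
            "prompt was passed but text_encoder/tokenizer is not loaded; "
            "the prompt will be ignored and zero prompt embeddings used."
        )
        return True
    return False
```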

Collaborator Author


agree, added one now

numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
prompt (`str` or `List[str]`, *optional*):
Collaborator


let's make it clear that it is an experimental feature, and if you pass prompt, you will need to load text_encoders explicitly

Collaborator Author


done

@sayakpaul
Member

THIS NEEDS TO BE IN THE DOCS.

@stevhliu any ideas about the location?

@stevhliu
Member

stevhliu commented Dec 3, 2024

Super cool! 🤩

We can add it to the "Specific Pipeline Examples" section and then build out the Flux doc there as discussed here.

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks! I think we can add the doc in a separate PR
but this needs make style

@linoytsaban
Collaborator Author

@yiyixuxu I'm not sure what the issue is; when I run make fixup it doesn't make any changes or flag any issues 🤔

@hlky
Contributor

hlky commented Dec 4, 2024

@linoytsaban It's the doc-builder check. Try doc-builder style src/diffusers docs/source --max_len 119 on its own, or make style; it's not covered by make fixup

@linoytsaban
Collaborator Author

thanks @hlky!

@yiyixuxu yiyixuxu merged commit 04bba38 into huggingface:main Dec 4, 2024
15 checks passed
@sayakpaul
Member

@linoytsaban let's add some docs (#10056 (comment)) and communicate?

@Thekey756

The results aren't very good; they don't match the quality of the official example.

@hlky
Contributor

hlky commented Dec 5, 2024

@Thekey756 do you have an example comparison to the original?

@linoytsaban
Collaborator Author

linoytsaban commented Dec 5, 2024

@Thekey756 I think to achieve the original effect you're referring to, we also need to take #10025 into consideration

@linoytsaban linoytsaban deleted the redux branch December 5, 2024 14:16
@linoytsaban
Collaborator Author

Example using prompts and attention masking for improved prompt adherence:

import torch

# assumes `pipe_prior_redux`, `pipe`, and `image` from the example above
pipe_prior_output = pipe_prior_redux(
    [image],
    prompt=["anime illustration"],
    prompt_embeds_scale=[1.0],
    pooled_prompt_embeds_scale=[1.0],
)
cond_size = 729
hidden_size = 4096
max_sequence_length = 512
full_attention_size = max_sequence_length + hidden_size + cond_size
attention_mask = torch.zeros(
    (full_attention_size, full_attention_size), device="cuda", dtype=torch.bfloat16
)
reference_scale: float = 0.04  # example value
bias = torch.log(
    torch.tensor(reference_scale, dtype=torch.bfloat16, device="cuda").clamp(min=1e-5, max=1)
)
attention_mask[:, max_sequence_length : max_sequence_length + cond_size] = bias
joint_attention_kwargs = dict(attention_mask=attention_mask)

pipe(
    guidance_scale=2.5,
    height=1024,
    width=1024,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
    joint_attention_kwargs=joint_attention_kwargs,
    **pipe_prior_output,
).images[0]
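For intuition: the mask adds log(reference_scale) to every query's attention score on the image-conditioning token columns, which downweights the reference image when the scale is below 1. A quick CPU sanity check of that layout, using the same sizes as above:

```python
import math
import torch

max_sequence_length, hidden_size, cond_size = 512, 4096, 729
full = max_sequence_length + hidden_size + cond_size
mask = torch.zeros((full, full))
# log of a scale < 1 is negative, so attention to the conditioning
# tokens is reduced; all other positions keep a bias of zero.
bias = math.log(min(max(0.04, 1e-5), 1.0))
mask[:, max_sequence_length : max_sequence_length + cond_size] = bias
```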

@hlky hlky mentioned this pull request Dec 12, 2024
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* add multiple prompts to flux redux

---------

Co-authored-by: hlky <[email protected]>
@lhjlhj11

> (quoting the PR description above)

Can you give an example of your Fast FLUX.1 Redux in a Hugging Face Space? Thanks! And what is the masking scale?

@MikeHanKK

how should I set prompt_embeds_scale, pooled_prompt_embeds_scale, and guidance_scale in general? @linoytsaban
