Commit 49e683f

start to work on edit

1 parent ff06e95 commit 49e683f

File tree

1 file changed: +189 -7 lines changed

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 189 additions & 7 deletions

@@ -25,9 +25,18 @@
 
 from .modular_pipeline import QwenImageModularPipeline
 
+from ...pipelines.qwenimage.pipeline_qwenimage import calculate_dimensions
+
 logger = logging.get_logger(__name__)
 
 
+def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
+    bool_mask = mask.bool()
+    valid_lengths = bool_mask.sum(dim=1)
+    selected = hidden_states[bool_mask]
+    split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+    return split_result
+
 def get_qwen_prompt_embeds(
     text_encoder,
     tokenizer,
@@ -53,13 +62,6 @@ def get_qwen_prompt_embeds(
     )
     hidden_states = encoder_hidden_states.hidden_states[-1]
 
-    def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
-        bool_mask = mask.bool()
-        valid_lengths = bool_mask.sum(dim=1)
-        selected = hidden_states[bool_mask]
-        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
-        return split_result
-
     split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
@@ -75,6 +77,55 @@ def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
 
     return prompt_embeds, encoder_attention_mask
 
+
+def get_qwen_prompt_embeds_edit(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[torch.Tensor] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+
+    template = prompt_template_encode
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(e) for e in prompt]
+
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    return prompt_embeds, encoder_attention_mask
+
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -139,6 +190,137 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         device = components._execution_device
         self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
 
+        block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
+            components.text_encoder,
+            components.tokenizer,
+            prompt=block_state.prompt,
+            prompt_template_encode=components.config.prompt_template_encode,
+            prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+            tokenizer_max_length=components.config.tokenizer_max_length,
+            device=device,
+        )
+
+        block_state.prompt_embeds = block_state.prompt_embeds[:, :block_state.max_sequence_length]
+        block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, :block_state.max_sequence_length]
+
+        if components.requires_unconditional_embeds:
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
+                components.text_encoder,
+                components.tokenizer,
+                prompt=block_state.negative_prompt,
+                prompt_template_encode=components.config.prompt_template_encode,
+                prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                tokenizer_max_length=components.config.tokenizer_max_length,
+                device=device,
+            )
+            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[:, :block_state.max_sequence_length]
+            block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[:, :block_state.max_sequence_length]
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageImageResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Image Resize step that resizes the image to the target area while maintaining the aspect ratio"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 16}), default_creation_method="from_config"),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if not isinstance(block_state.image, list):
+            block_state.image = [block_state.image]
+
+        image_width, image_height = block_state.image[0].size
+        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+        # resize each input image to the computed target size
+        block_state.image = [
+            components.image_processor.resize(img, height=calculated_height, width=calculated_width)
+            for img in block_state.image
+        ]
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text embeddings to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+            ComponentSpec("processor", Qwen2VLProcessor),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 4.0}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(name="prompt_template_encode", default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"),
+            ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+            InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+            InputParam(name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="prompt_embeds", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The prompt embeddings"),
+            OutputParam(name="prompt_embeds_mask", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The encoder attention mask"),
+            OutputParam(name="negative_prompt_embeds", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The negative prompt embeddings"),
+            OutputParam(name="negative_prompt_embeds_mask", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The negative prompt embeddings mask"),
+        ]
+
+    @staticmethod
+    def check_inputs(prompt, negative_prompt, max_sequence_length):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if max_sequence_length is not None and max_sequence_length > 1024:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
+
+        device = components._execution_device
+        image = components.image_processor.preprocess(block_state.image)
+
+
         block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
             components.text_encoder,
             components.tokenizer,

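A note for readers skimming the diff: both `get_qwen_prompt_embeds` and the new `get_qwen_prompt_embeds_edit` do the same thing once the encoder has run, i.e. drop the template tokens, keep only the positions the attention mask marks as valid, and right-pad everything back to a common length before stacking. The following is a minimal, self-contained sketch of that packing logic; the toy tensors and the small sizes are invented for illustration, and only `torch` is assumed.

```python
import torch

def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
    # Keep only positions where the attention mask is 1, then split the flat
    # result back into one variable-length sequence per batch element.
    bool_mask = mask.bool()
    valid_lengths = bool_mask.sum(dim=1)
    selected = hidden_states[bool_mask]
    return torch.split(selected, valid_lengths.tolist(), dim=0)

# Toy inputs: batch of 2 sequences, max length 5, hidden size 4.
hidden = torch.randn(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
drop_idx = 2  # pretend the first 2 tokens belong to the chat template

split = _extract_masked_hidden(hidden, mask)   # per-sample lengths: [3, 5]
split = [e[drop_idx:] for e in split]          # after dropping template: [1, 3]
attn = [torch.ones(e.size(0), dtype=torch.long) for e in split]
max_len = max(e.size(0) for e in split)

# Right-pad every sequence (and its mask) with zeros up to max_len, then stack.
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_len - u.size(0), u.size(1))]) for u in split]
)
attention_mask = torch.stack(
    [torch.cat([u, u.new_zeros(max_len - u.size(0))]) for u in attn]
)

print(prompt_embeds.shape)  # torch.Size([2, 3, 4])
print(attention_mask)       # tensor([[1, 0, 0], [1, 1, 1]])
```

In the diff this packing appears twice with the same shape conventions: once with `txt_tokens.attention_mask` for text-only prompts, and once with `model_inputs.attention_mask` for the multimodal edit path.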
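`QwenImageImageResizeStep` maps the input image onto a roughly 1024x1024 pixel budget while preserving its aspect ratio via `calculate_dimensions(1024 * 1024, image_width / image_height)`. The sketch below is a plausible stand-in for that computation, written only for illustration: `_calculate_dimensions` is a hypothetical local helper and the multiple-of-32 rounding is an assumption, not the behavior of the function imported in the diff.

```python
import math

def _calculate_dimensions(target_area: int, ratio: float, multiple: int = 32):
    # Pick (width, height) with width/height ~ ratio and width*height ~ target_area,
    # snapped to a multiple that downstream VAE/patching stages typically require.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return width, height

# A 1536x1024 (3:2) input mapped onto a 1024*1024 pixel budget.
w, h = _calculate_dimensions(1024 * 1024, 1536 / 1024)
print(w, h)  # 1248 832 -> aspect ratio 1.5 preserved, area close to 1024*1024
```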