Commit f82c152

fix prompt isolation test.
1 parent ec5449f commit f82c152

File tree

3 files changed: +39 -31 lines changed


tests/pipelines/qwenimage/test_qwenimage_edit.py

Lines changed: 5 additions & 4 deletions
@@ -15,7 +15,6 @@
 import unittest

 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
@@ -238,6 +237,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             "VAE tiling should not affect the inference results",
         )

-    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
-        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+    def test_encode_prompt_works_in_isolation(
+        self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
+    ):
+        keep_params = ["image"]
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
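Instead of xfail-ing, the subclass now runs the base test with keep_params pinned to ["image"]: `image` is consumed by both `encode_prompt()` and the pipeline call, so it must survive when the encode-prompt arguments are extracted from `inputs`. A toy sketch of that shared-parameter problem (made-up functions, not the diffusers API):

# Toy illustration: `image` feeds both the prompt encoder and the pipeline call.
def encode_prompt(prompt, image):
    return {"prompt_embeds": f"embeds({prompt}, {image})"}


def pipeline_call(image, prompt=None, prompt_embeds=None):
    # even with precomputed embeds, the call still needs the raw image
    return {"image": image, "embeds": prompt_embeds}


inputs = {"prompt": "a cat", "image": "<pil image>"}
embeds = encode_prompt(inputs["prompt"], inputs["image"])

inputs.pop("prompt")   # safe to drop: the embeds replace the raw prompt
# inputs.pop("image")  # would break the call below -- hence keep_params=["image"]

print(pipeline_call(**inputs, **embeds))

The same override lands in the Edit Plus test below.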

tests/pipelines/qwenimage/test_qwenimage_edit_plus.py

Lines changed: 5 additions & 3 deletions
@@ -236,9 +236,11 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             "VAE tiling should not affect the inference results",
         )

-    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
-        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+    def test_encode_prompt_works_in_isolation(
+        self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
+    ):
+        keep_params = ["image"]
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)

     @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
     def test_num_images_per_prompt(self):

tests/pipelines/test_pipelines_common.py

Lines changed: 29 additions & 24 deletions
@@ -5,7 +5,7 @@
 import tempfile
 import unittest
 import uuid
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import PIL.Image
@@ -2069,20 +2069,26 @@ def test_loading_with_incorrect_variants_raises_error(self):

         assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)

-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+    def test_encode_prompt_works_in_isolation(
+        self,
+        extra_required_param_value_dict: Optional[dict] = None,
+        keep_params: Optional[list] = None,
+        atol=1e-4,
+        rtol=1e-4,
+    ):
         if not hasattr(self.pipeline_class, "encode_prompt"):
             return

         components = self.get_dummy_components()

+        def _contains_text_key(name):
+            return any(token in name for token in ("text", "tokenizer", "processor"))
+
         # We initialize the pipeline with only text encoders and tokenizers,
         # mimicking a real-world scenario.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = components[k]
-            else:
-                components_with_text_encoders[k] = None
+        components_with_text_encoders = {
+            name: component if _contains_text_key(name) else None for name, component in components.items()
+        }
         pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
         pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)

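The old inline `"text" in k or "tokenizer" in k` check becomes a `_contains_text_key` helper that also matches "processor", presumably so multimodal components such as the `Qwen2VLProcessor` used by the QwenImage Edit pipelines stay on the text-encoding side instead of being nulled out. A quick sketch of the split with made-up component names:

def _contains_text_key(name):
    return any(token in name for token in ("text", "tokenizer", "processor"))


# hypothetical components dict; a real pipeline supplies actual modules here
components = {"text_encoder": "enc", "tokenizer": "tok", "processor": "proc",
              "transformer": "dit", "vae": "vae", "scheduler": "sched"}

with_text_only = {n: c if _contains_text_key(n) else None for n, c in components.items()}
print(with_text_only)
# {'text_encoder': 'enc', 'tokenizer': 'tok', 'processor': 'proc',
#  'transformer': None, 'vae': None, 'scheduler': None}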
@@ -2092,17 +2098,19 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=
         encode_prompt_parameters = list(encode_prompt_signature.parameters.values())

         # Required args in encode_prompt with those with no default.
-        required_params = []
-        for param in encode_prompt_parameters:
-            if param.name == "self" or param.name == "kwargs":
-                continue
-            if param.default is inspect.Parameter.empty:
-                required_params.append(param.name)
+        required_params = [
+            param.name
+            for param in encode_prompt_parameters
+            if param.name not in {"self", "kwargs"} and param.default is inspect.Parameter.empty
+        ]

         # Craft inputs for the `encode_prompt()` method to run in isolation.
         encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
-        input_keys = list(inputs.keys())
-        encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
+        encode_prompt_inputs = {name: inputs[name] for name in encode_prompt_param_names if name in inputs}
+        if keep_params:
+            for name in encode_prompt_param_names:
+                if name in inputs and name not in keep_params:
+                    inputs.pop(name)

         pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
         pipe_call_parameters = pipe_call_signature.parameters
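The required-argument scan is plain `inspect.signature` introspection; a self-contained sketch with an illustrative stub signature:

import inspect


def encode_prompt(self, prompt, image, device=None, num_images_per_prompt=1, **kwargs):
    """Illustrative stub; only the signature matters for the scan."""


params = inspect.signature(encode_prompt).parameters.values()
required_params = [
    p.name
    for p in params
    if p.name not in {"self", "kwargs"} and p.default is inspect.Parameter.empty
]
print(required_params)  # ['prompt', 'image'] -- no default, not self/kwargs

Note the behavior change in the extraction step: the old code always popped every `encode_prompt` argument out of `inputs`, while the new code leaves `inputs` untouched unless `keep_params` is given, in which case everything except the listed names is popped.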
@@ -2137,18 +2145,15 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=

         # Pack the outputs of `encode_prompt`.
         adapted_prompt_embeds_kwargs = {
-            k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
+            name: prompt_embeds_kwargs[name] for name in prompt_embeds_kwargs if name in pipe_call_parameters
         }

         # now initialize a pipeline without text encoders and compute outputs with the
         # `encode_prompt()` outputs and other relevant inputs.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = None
-            else:
-                components_with_text_encoders[k] = components[k]
-        pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
+        components_without_text_encoders = {
+            name: None if _contains_text_key(name) else component for name, component in components.items()
+        }
+        pipe_without_text_encoders = self.pipeline_class(**components_without_text_encoders).to(torch_device)

         # Set `negative_prompt` to None as we have already calculated its embeds
         # if it was present in `inputs`. This is because otherwise we will interfere wrongly
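The packing step keeps the same filter-to-call-signature behavior but no longer mutates `prompt_embeds_kwargs`, and the renamed `components_without_text_encoders` dict makes the text-free mirror of the earlier split explicit. A toy version of the filtering (placeholder values):

prompt_embeds_kwargs = {"prompt_embeds": "E+", "negative_prompt_embeds": "E-", "lora_scale": 1.0}
pipe_call_parameters = {"prompt", "image", "prompt_embeds", "negative_prompt_embeds"}

adapted = {name: prompt_embeds_kwargs[name] for name in prompt_embeds_kwargs if name in pipe_call_parameters}
print(adapted)  # 'lora_scale' is dropped; only kwargs the call accepts survive

Nulling `negative_prompt` afterwards matters because most diffusers pipelines raise a ValueError when a raw (negative) prompt and its precomputed embeddings are passed together.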
