Skip to content

Commit bfa253a

Browse files
committed
poc encode_prompt() tests
1 parent c28db0a commit bfa253a

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte
2727
pipeline_class = FluxPipeline
2828
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
2929
batch_params = frozenset(["prompt"])
30+
prompt_embed_kwargs = ("prompt_embeds", "pooled_prompt_embeds", "text_ids")
3031

3132
# there is no xformers processor for Flux
3233
test_xformers_attention = False

tests/pipelines/test_pipelines_common.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,8 @@ class PipelineTesterMixin:
986986

987987
test_xformers_attention = True
988988

989+
prompt_embed_kwargs = ("prompt_embeds", "negative_prompt_embeds") # most common return-type across the pipelines.
990+
989991
def get_generator(self, seed):
990992
device = torch_device if torch_device != "mps" else "cpu"
991993
generator = torch.Generator(device).manual_seed(seed)
@@ -1976,6 +1978,62 @@ def test_loading_with_incorrect_variants_raises_error(self):
19761978

19771979
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
19781980

1981+
def test_encode_prompt_works_in_isolation(self):
1982+
if not hasattr(self.pipeline_class, "encode_prompt"):
1983+
return
1984+
components = self.get_dummy_components()
1985+
pipe = self.pipeline_class(**components)
1986+
pipe = pipe.to(torch_device)
1987+
pipe.set_progress_bar_config(disable=None)
1988+
1989+
inputs = self.get_dummy_inputs(torch_device)
1990+
encode_prompt_signature = inspect.signature(pipe.encode_prompt)
1991+
encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
1992+
1993+
# Required parameters in encode_prompt = those with no default
1994+
required_params = []
1995+
for param in encode_prompt_parameters:
1996+
if param.name == "self":
1997+
continue
1998+
if param.default is inspect.Parameter.empty:
1999+
required_params.append(param.name)
2000+
2001+
encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
2002+
input_keys = list(inputs.keys())
2003+
encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
2004+
2005+
pipe_call_signature = inspect.signature(pipe.__call__)
2006+
pipe_call_parameters = pipe_call_signature.parameters
2007+
2008+
# For each required param in encode_prompt, check if it's missing
2009+
# in encode_prompt_inputs. If so, see if __call__ has a default
2010+
# for that param and use it if available.
2011+
for required_param_name in required_params:
2012+
if required_param_name not in encode_prompt_inputs:
2013+
pipe_call_param = pipe_call_parameters.get(required_param_name, None)
2014+
if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
2015+
# Use the default from pipe.__call__
2016+
encode_prompt_inputs[required_param_name] = pipe_call_param.default
2017+
else:
2018+
raise ValueError(
2019+
f"Required parameter '{required_param_name}' in "
2020+
f"encode_prompt has no default in either encode_prompt or __call__."
2021+
)
2022+
2023+
with torch.no_grad():
2024+
encoded_prompt_outputs = pipe.encode_prompt(**encode_prompt_inputs)
2025+
2026+
prompt_embeds_kwargs = dict(zip(self.prompt_embed_kwargs, encoded_prompt_outputs))
2027+
adapted_prompt_embeds_kwargs = {
2028+
k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
2029+
}
2030+
pipe_out = pipe(**inputs, **adapted_prompt_embeds_kwargs)[0]
2031+
2032+
inputs = self.get_dummy_inputs(torch_device)
2033+
pipe_out_2 = pipe(**inputs)[0]
2034+
2035+
self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=1e-4, rtol=1e-4))
2036+
19792037
def test_StableDiffusionMixin_component(self):
19802038
"""Any pipeline that have LDMFuncMixin should have vae and unet components."""
19812039
if not issubclass(self.pipeline_class, StableDiffusionMixin):

0 commit comments

Comments
 (0)