Commit f1e58a6
committed: updates.
1 parent e2f34ad

File tree: 2 files changed, +84 −13 lines
diffusers/utils/source_code_parsing_utils.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import ast
+import importlib
+import inspect
+import textwrap
+
+
+class ReturnNameVisitor(ast.NodeVisitor):
+    """Thanks to ChatGPT for pairing."""
+    def __init__(self):
+        self.return_names = []
+
+    def visit_Return(self, node):
+        # Check if the return value is a tuple.
+        if isinstance(node.value, ast.Tuple):
+            for elt in node.value.elts:
+                if isinstance(elt, ast.Name):
+                    self.return_names.append(elt.id)
+                else:
+                    try:
+                        self.return_names.append(ast.unparse(elt))
+                    except Exception:
+                        self.return_names.append(str(elt))
+        else:
+            if isinstance(node.value, ast.Name):
+                self.return_names.append(node.value.id)
+            else:
+                try:
+                    self.return_names.append(ast.unparse(node.value))
+                except Exception:
+                    self.return_names.append(str(node.value))
+        self.generic_visit(node)
+
+    def _determine_parent_module(self, cls):
+        from diffusers import DiffusionPipeline
+        from diffusers.models.modeling_utils import ModelMixin
+
+        if issubclass(cls, DiffusionPipeline):
+            return "pipelines"
+        elif issubclass(cls, ModelMixin):
+            return "models"
+        else:
+            raise NotImplementedError
+
+    def get_ast_tree(self, cls, attribute_name="encode_prompt"):
+        parent_module_name = self._determine_parent_module(cls)
+        main_module = importlib.import_module(f"diffusers.{parent_module_name}")
+        current_cls_module = getattr(main_module, cls.__name__)
+        source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
+        source_code = textwrap.dedent(source_code)
+        tree = ast.parse(source_code)
+        return tree
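
For illustration, a minimal sketch of what the visitor collects, run against a hypothetical toy function (only the standard-library ast module and the class above are needed; the toy encode_prompt is not from the commit):

import ast

from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor

# Hypothetical toy function whose return names we want to recover.
source = """
def encode_prompt(prompt):
    prompt_embeds = prompt
    negative_prompt_embeds = None
    return prompt_embeds, negative_prompt_embeds
"""

visitor = ReturnNameVisitor()
visitor.visit(ast.parse(source))
# Both tuple elements are ast.Name nodes, so their ids are recorded in order.
print(visitor.return_names)  # ['prompt_embeds', 'negative_prompt_embeds']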

tests/pipelines/test_pipelines_common.py

Lines changed: 33 additions & 13 deletions
@@ -5,6 +5,8 @@
 import tempfile
 import unittest
 import uuid
+import textwrap
+import ast
 from typing import Any, Callable, Dict, Union

 import numpy as np
@@ -14,7 +16,7 @@
1416
from huggingface_hub import ModelCard, delete_repo
1517
from huggingface_hub.utils import is_jinja_available
1618
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
17-
19+
import importlib
1820
import diffusers
1921
from diffusers import (
2022
AsymmetricAutoencoderKL,
@@ -51,6 +53,7 @@
     skip_mps,
     torch_device,
 )
+from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor

 from ..models.autoencoders.vae import (
     get_asym_autoencoder_kl_config,
@@ -1997,26 +2000,27 @@ def test_encode_prompt_works_in_isolation(self):
                 components_with_text_encoders[k] = components[k]
             else:
                 components_with_text_encoders[k] = None
-        pipe = self.pipeline_class(**components_with_text_encoders)
-        pipe = pipe.to(torch_device)
+        pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
+        pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)

         inputs = self.get_dummy_inputs(torch_device)
-        encode_prompt_signature = inspect.signature(pipe.encode_prompt)
+        encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt)
         encode_prompt_parameters = list(encode_prompt_signature.parameters.values())

-        # Required parameters in encode_prompt = those with no default
+        # Required parameters in encode_prompt are those with no default.
         required_params = []
         for param in encode_prompt_parameters:
-            if param.name == "self":
+            if param.name == "self" or param.name == "kwargs":
                 continue
             if param.default is inspect.Parameter.empty:
                 required_params.append(param.name)

+        # Craft inputs for the `encode_prompt()` method to run in isolation.
         encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
         input_keys = list(inputs.keys())
         encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}

-        pipe_call_signature = inspect.signature(pipe.__call__)
+        pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
         pipe_call_parameters = pipe_call_signature.parameters

         # For each required param in encode_prompt, check if it's missing
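
As a side note on the `self`/`kwargs` skip in the hunk above, a hedged sketch with a toy, hypothetical encode_prompt signature shows why both names are excluded by name even though neither has a default (standard-library inspect only):

import inspect

# Hypothetical signature loosely mirroring a pipeline's encode_prompt.
def encode_prompt(self, prompt, device=None, num_images_per_prompt=1, **kwargs):
    pass

required_params = [
    p.name
    for p in inspect.signature(encode_prompt).parameters.values()
    # `self` and the VAR_KEYWORD catch-all `kwargs` lack defaults but are not
    # user-facing required arguments, hence the explicit name check.
    if p.name not in ("self", "kwargs") and p.default is inspect.Parameter.empty
]
print(required_params)  # ['prompt']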
@@ -2034,28 +2038,44 @@ def test_encode_prompt_works_in_isolation(self):
                     f"encode_prompt has no default in either encode_prompt or __call__."
                 )

+        # Compute `encode_prompt()`.
         with torch.no_grad():
-            encoded_prompt_outputs = pipe.encode_prompt(**encode_prompt_inputs)
-
-        prompt_embeds_kwargs = dict(zip(self.prompt_embed_kwargs, encoded_prompt_outputs))
+            encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
+
+        # Programmatically determine the return names of `encode_prompt()`.
+        ast_visitor = ReturnNameVisitor()
+        encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class)
+        ast_visitor.visit(encode_prompt_tree)
+        prompt_embed_kwargs = ast_visitor.return_names
+        print(f"{prompt_embed_kwargs=}")
+        prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))
+        # Pack the outputs of `encode_prompt`.
         adapted_prompt_embeds_kwargs = {
             k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
         }

-        # now initialize a pipeline without text encoders
+        # Now initialize a pipeline without text encoders and compute outputs with
+        # the `encode_prompt()` outputs and other relevant inputs.
         components_with_text_encoders = {}
         for k in components:
             if "text" in k or "tokenizer" in k:
                 components_with_text_encoders[k] = None
             else:
                 components_with_text_encoders[k] = components[k]
         pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
-        pipe_out = pipe_without_text_encoders(**inputs, **adapted_prompt_embeds_kwargs)[0]

+        # Set `negative_prompt` to None as we have already calculated its embeds if it
+        # was present in `inputs`. Otherwise, pipelines with a non-None `negative_prompt`
+        # default (PixArt, for example) would wrongly interfere with the precomputed embeds.
+        pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs}
+        if pipe_call_parameters.get("negative_prompt", None) is not None:
+            pipe_without_tes_inputs.update({"negative_prompt": None})
+        pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]
+
+        # Compare against regular pipeline outputs.
         full_pipe = self.pipeline_class(**components).to(torch_device)
         inputs = self.get_dummy_inputs(torch_device)
         pipe_out_2 = full_pipe(**inputs)[0]
-
         self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=1e-4, rtol=1e-4))

     def test_StableDiffusionMixin_component(self):
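
Putting both files together, the intended end-to-end flow is roughly the following sketch (assuming diffusers is installed; StableDiffusionPipeline is only an example class, and the printed names depend on that pipeline's actual return statement):

from diffusers import StableDiffusionPipeline
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor

# Parse the source of the pipeline's encode_prompt and collect its return names.
visitor = ReturnNameVisitor()
tree = visitor.get_ast_tree(cls=StableDiffusionPipeline)
visitor.visit(tree)
print(visitor.return_names)  # e.g. ['prompt_embeds', 'negative_prompt_embeds']

The test then zips these names with the actual encode_prompt() outputs and feeds the result back into a text-encoder-free pipeline as prompt_embeds-style keyword arguments.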
