55import tempfile
66import unittest
77import uuid
8+ import textwrap
9+ import ast
810from typing import Any , Callable , Dict , Union
911
1012import numpy as np
1416from huggingface_hub import ModelCard , delete_repo
1517from huggingface_hub .utils import is_jinja_available
1618from transformers import CLIPTextConfig , CLIPTextModel , CLIPTokenizer
17-
19+ import importlib
1820import diffusers
1921from diffusers import (
2022 AsymmetricAutoencoderKL ,
5153 skip_mps ,
5254 torch_device ,
5355)
56+ from diffusers .utils .source_code_parsing_utils import ReturnNameVisitor
5457
5558from ..models .autoencoders .vae import (
5659 get_asym_autoencoder_kl_config ,
@@ -1997,26 +2000,27 @@ def test_encode_prompt_works_in_isolation(self):
19972000 components_with_text_encoders [k ] = components [k ]
19982001 else :
19992002 components_with_text_encoders [k ] = None
2000- pipe = self .pipeline_class (** components_with_text_encoders )
2001- pipe = pipe .to (torch_device )
2003+ pipe_with_just_text_encoder = self .pipeline_class (** components_with_text_encoders )
2004+ pipe_with_just_text_encoder = pipe_with_just_text_encoder .to (torch_device )
20022005
20032006 inputs = self .get_dummy_inputs (torch_device )
2004- encode_prompt_signature = inspect .signature (pipe .encode_prompt )
2007+ encode_prompt_signature = inspect .signature (pipe_with_just_text_encoder .encode_prompt )
20052008 encode_prompt_parameters = list (encode_prompt_signature .parameters .values ())
20062009
2007- # Required parameters in encode_prompt = those with no default
2010 # Required parameters in encode_prompt are those with no default
20082011 required_params = []
20092012 for param in encode_prompt_parameters :
2010- if param .name == "self" :
2013+ if param .name == "self" or param . name == "kwargs" :
20112014 continue
20122015 if param .default is inspect .Parameter .empty :
20132016 required_params .append (param .name )
20142017
2018+ # Craft inputs for the `encode_prompt()` method to run in isolation.
20152019 encode_prompt_param_names = [p .name for p in encode_prompt_parameters if p .name != "self" ]
20162020 input_keys = list (inputs .keys ())
20172021 encode_prompt_inputs = {k : inputs .pop (k ) for k in input_keys if k in encode_prompt_param_names }
20182022
2019- pipe_call_signature = inspect .signature (pipe .__call__ )
2023+ pipe_call_signature = inspect .signature (pipe_with_just_text_encoder .__call__ )
20202024 pipe_call_parameters = pipe_call_signature .parameters
20212025
20222026 # For each required param in encode_prompt, check if it's missing
@@ -2034,28 +2038,44 @@ def test_encode_prompt_works_in_isolation(self):
20342038 f"encode_prompt has no default in either encode_prompt or __call__."
20352039 )
20362040
2041+ # Compute `encode_prompt()`.
20372042 with torch .no_grad ():
2038- encoded_prompt_outputs = pipe .encode_prompt (** encode_prompt_inputs )
2039-
2040- prompt_embeds_kwargs = dict (zip (self .prompt_embed_kwargs , encoded_prompt_outputs ))
2043+ encoded_prompt_outputs = pipe_with_just_text_encoder .encode_prompt (** encode_prompt_inputs )
2044+
2045+ # Programmatically determine the return names of `encode_prompt`.
2046+ ast_vistor = ReturnNameVisitor ()
2047+ encode_prompt_tree = ast_vistor .get_ast_tree (cls = self .pipeline_class )
2048+ ast_vistor .visit (encode_prompt_tree )
2049+ prompt_embed_kwargs = ast_vistor .return_names
2050+ print (f"{ prompt_embed_kwargs = } " )
2051+ prompt_embeds_kwargs = dict (zip (prompt_embed_kwargs , encoded_prompt_outputs ))
2052+ # Pack the outputs of `encode_prompt`.
20412053 adapted_prompt_embeds_kwargs = {
20422054 k : prompt_embeds_kwargs .pop (k ) for k in list (prompt_embeds_kwargs .keys ()) if k in pipe_call_parameters
20432055 }
20442056
2045- # now initialize a pipeline without text encoders
2057+ # now initialize a pipeline without text encoders and compute outputs with the
2058+ # `encode_prompt()` outputs and other relevant inputs.
20462059 components_with_text_encoders = {}
20472060 for k in components :
20482061 if "text" in k or "tokenizer" in k :
20492062 components_with_text_encoders [k ] = None
20502063 else :
20512064 components_with_text_encoders [k ] = components [k ]
20522065 pipe_without_text_encoders = self .pipeline_class (** components_with_text_encoders ).to (torch_device )
2053- pipe_out = pipe_without_text_encoders (** inputs , ** adapted_prompt_embeds_kwargs )[0 ]
20542066
2067+ # Set `negative_prompt` to None as we have already calculated its embeds
2068+ # if it was present in `inputs`. Otherwise we would wrongly pick up
2069+ # non-None `negative_prompt` defaults (PixArt, for example).
2070+ pipe_without_tes_inputs = {** inputs , ** adapted_prompt_embeds_kwargs }
2071+ if pipe_call_parameters .get ("negative_prompt" , None ) is not None :
2072+ pipe_without_tes_inputs .update ({"negative_prompt" : None })
2073+ pipe_out = pipe_without_text_encoders (** pipe_without_tes_inputs )[0 ]
2074+
2075+ # Compare against regular pipeline outputs.
20552076 full_pipe = self .pipeline_class (** components ).to (torch_device )
20562077 inputs = self .get_dummy_inputs (torch_device )
20572078 pipe_out_2 = full_pipe (** inputs )[0 ]
2058-
20592079 self .assertTrue (np .allclose (pipe_out , pipe_out_2 , atol = 1e-4 , rtol = 1e-4 ))
20602080
20612081 def test_StableDiffusionMixin_component (self ):
0 commit comments