 import tempfile
 import unittest
 import uuid
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import PIL.Image
@@ -2069,20 +2069,26 @@ def test_loading_with_incorrect_variants_raises_error(self):

         assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)

-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+    def test_encode_prompt_works_in_isolation(
+        self,
+        extra_required_param_value_dict: Optional[dict] = None,
+        keep_params: Optional[list] = None,
+        atol=1e-4,
+        rtol=1e-4,
+    ):
         if not hasattr(self.pipeline_class, "encode_prompt"):
             return

         components = self.get_dummy_components()

+        def _contains_text_key(name):
+            return any(token in name for token in ("text", "tokenizer", "processor"))
+
         # We initialize the pipeline with only text encoders and tokenizers,
         # mimicking a real-world scenario.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = components[k]
-            else:
-                components_with_text_encoders[k] = None
+        components_with_text_encoders = {
+            name: component if _contains_text_key(name) else None for name, component in components.items()
+        }
         pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
         pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)

@@ -2092,17 +2098,19 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=
         encode_prompt_parameters = list(encode_prompt_signature.parameters.values())

         # Required args in encode_prompt with those with no default.
-        required_params = []
-        for param in encode_prompt_parameters:
-            if param.name == "self" or param.name == "kwargs":
-                continue
-            if param.default is inspect.Parameter.empty:
-                required_params.append(param.name)
+        required_params = [
+            param.name
+            for param in encode_prompt_parameters
+            if param.name not in {"self", "kwargs"} and param.default is inspect.Parameter.empty
+        ]

         # Craft inputs for the `encode_prompt()` method to run in isolation.
         encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
-        input_keys = list(inputs.keys())
-        encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
+        encode_prompt_inputs = {name: inputs[name] for name in encode_prompt_param_names if name in inputs}
+        # Pop consumed params from `inputs` (unless listed in `keep_params`) so they are not passed twice below.
+        for name in encode_prompt_param_names:
+            if name in inputs and (keep_params is None or name not in keep_params):
+                inputs.pop(name)

         pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
         pipe_call_parameters = pipe_call_signature.parameters
@@ -2137,18 +2145,15 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=

         # Pack the outputs of `encode_prompt`.
         adapted_prompt_embeds_kwargs = {
-            k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
+            name: prompt_embeds_kwargs[name] for name in prompt_embeds_kwargs if name in pipe_call_parameters
         }

         # now initialize a pipeline without text encoders and compute outputs with the
         # `encode_prompt()` outputs and other relevant inputs.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = None
-            else:
-                components_with_text_encoders[k] = components[k]
-        pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
+        components_without_text_encoders = {
+            name: None if _contains_text_key(name) else component for name, component in components.items()
+        }
+        pipe_without_text_encoders = self.pipeline_class(**components_without_text_encoders).to(torch_device)

         # Set `negative_prompt` to None as we have already calculated its embeds
         # if it was present in `inputs`. This is because otherwise we will interfere wrongly
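
For context, a minimal usage sketch of the updated hook (not part of this commit): `MyPipeline`, `PipelineTesterMixin`, `torch_device`, and the parameter values below are assumptions for illustration. A pipeline-specific test would pass any `encode_prompt()` argument that has no default via `extra_required_param_value_dict`, and list in `keep_params` the pipeline-call inputs it wants left in `inputs` after `encode_prompt()` has consumed them.

import unittest


# Hypothetical subclass; `PipelineTesterMixin`, `MyPipeline`, and `torch_device` are assumed
# to be provided by the surrounding test module.
class MyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = MyPipeline

    def test_encode_prompt_works_in_isolation(self):
        # Assume this pipeline's `encode_prompt()` requires a `device` argument with no default.
        extra_required_param_value_dict = {"device": torch_device}
        # Keep `max_sequence_length` in the pipeline-call inputs instead of popping it.
        return super().test_encode_prompt_works_in_isolation(
            extra_required_param_value_dict,
            keep_params=["max_sequence_length"],
        )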