@@ -986,6 +986,8 @@ class PipelineTesterMixin:
 
     test_xformers_attention = True
 
+    prompt_embed_kwargs = ("prompt_embeds", "negative_prompt_embeds")  # most common return type across the pipelines.
+
     def get_generator(self, seed):
         device = torch_device if torch_device != "mps" else "cpu"
         generator = torch.Generator(device).manual_seed(seed)
@@ -1976,6 +1978,68 @@ def test_loading_with_incorrect_variants_raises_error(self):
 
         assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
 
+    def test_encode_prompt_works_in_isolation(self):
+        if not hasattr(self.pipeline_class, "encode_prompt"):
+            return
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        encode_prompt_signature = inspect.signature(pipe.encode_prompt)
+        encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
+
+        # Required parameters in encode_prompt = those with no default.
+        required_params = []
+        for param in encode_prompt_parameters:
+            if param.name == "self":
+                continue
+            if param.default is inspect.Parameter.empty:
+                required_params.append(param.name)
+
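+        # Split the dummy inputs: keys accepted by encode_prompt are popped out,
+        # everything else stays in `inputs` for the final pipeline call.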
+        encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
+        input_keys = list(inputs.keys())
+        encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
+
+        pipe_call_signature = inspect.signature(pipe.__call__)
+        pipe_call_parameters = pipe_call_signature.parameters
+
+        # For each required param in encode_prompt, check if it's missing
+        # in encode_prompt_inputs. If so, see if __call__ has a default
+        # for that param and use it if available.
+        for required_param_name in required_params:
+            if required_param_name not in encode_prompt_inputs:
+                pipe_call_param = pipe_call_parameters.get(required_param_name, None)
+                if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
+                    # Use the default from pipe.__call__.
+                    encode_prompt_inputs[required_param_name] = pipe_call_param.default
+                else:
+                    raise ValueError(
+                        f"Required parameter '{required_param_name}' in "
+                        "encode_prompt has no default in either encode_prompt or __call__."
+                    )
+
+        with torch.no_grad():
+            encoded_prompt_outputs = pipe.encode_prompt(**encode_prompt_inputs)
+
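+        # Pair the positional outputs of encode_prompt with their kwarg names and
+        # keep only the ones that this pipeline's __call__ actually accepts.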
+        prompt_embeds_kwargs = dict(zip(self.prompt_embed_kwargs, encoded_prompt_outputs))
+        adapted_prompt_embeds_kwargs = {
+            k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
+        }
+        pipe_out = pipe(**inputs, **adapted_prompt_embeds_kwargs)[0]
+
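+        # Reference run: let the pipeline encode the prompt internally and check
+        # that both code paths produce (numerically) the same output.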
+        inputs = self.get_dummy_inputs(torch_device)
+        pipe_out_2 = pipe(**inputs)[0]
+
+        self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=1e-4, rtol=1e-4))
+
     def test_StableDiffusionMixin_component(self):
         """Any pipeline that has LDMFuncMixin should have vae and unet components."""
         if not issubclass(self.pipeline_class, StableDiffusionMixin):