Commit 6ce9128

fix
1 parent bfa253a commit 6ce9128

File tree: 1 file changed (+23 lines, −4 lines)


tests/pipelines/test_pipelines_common.py

Lines changed: 23 additions & 4 deletions
@@ -1981,10 +1981,19 @@ def test_loading_with_incorrect_variants_raises_error(self):
     def test_encode_prompt_works_in_isolation(self):
         if not hasattr(self.pipeline_class, "encode_prompt"):
             return
+
         components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
+
+        # We initialize the pipeline with only text encoders and tokenizers,
+        # mimicking a real-world scenario.
+        components_with_text_encoders = {}
+        for k in components:
+            if "text" in k or "tokenizer" in k:
+                components_with_text_encoders[k] = components[k]
+            else:
+                components_with_text_encoders[k] = None
+        pipe = self.pipeline_class(**components_with_text_encoders)
         pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
         encode_prompt_signature = inspect.signature(pipe.encode_prompt)
@@ -2027,10 +2036,20 @@ def test_encode_prompt_works_in_isolation(self):
         adapted_prompt_embeds_kwargs = {
             k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
         }
-        pipe_out = pipe(**inputs, **adapted_prompt_embeds_kwargs)[0]
 
+        # now initialize a pipeline without text encoders
+        components_with_text_encoders = {}
+        for k in components:
+            if "text" in k or "tokenizer" in k:
+                components_with_text_encoders[k] = None
+            else:
+                components_with_text_encoders[k] = components[k]
+        pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
+        pipe_out = pipe_without_text_encoders(**inputs, **adapted_prompt_embeds_kwargs)[0]
+
+        full_pipe = self.pipeline_class(**components).to(torch_device)
         inputs = self.get_dummy_inputs(torch_device)
-        pipe_out_2 = pipe(**inputs)[0]
+        pipe_out_2 = full_pipe(**inputs)[0]
 
         self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=1e-4, rtol=1e-4))
 

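For context, the changed test now mirrors a real-world flow: encode prompts with a pipeline that still has its text encoders and tokenizers, then reuse the precomputed embeddings in a pipeline whose text encoders are set to None. A minimal sketch of that flow is below; it is not part of this commit, the checkpoint id is a placeholder, and it assumes from_pretrained accepts None to skip loading a component and that the pipeline exposes encode_prompt as in recent diffusers releases.

# Hedged sketch (not part of this commit): encode prompts in isolation, then
# run generation in a pipeline loaded without its text encoders.
# Assumptions: a Stable Diffusion-style checkpoint (placeholder id below) and
# that passing None to from_pretrained skips loading that component.
import torch
from diffusers import StableDiffusionPipeline

repo_id = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint id
device = "cuda" if torch.cuda.is_available() else "cpu"

# Step 1: a pipeline that still has its text encoder/tokenizer encodes the prompt.
text_pipe = StableDiffusionPipeline.from_pretrained(repo_id).to(device)
prompt_embeds, negative_prompt_embeds = text_pipe.encode_prompt(
    "an astronaut riding a horse",
    device=device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# Step 2: a pipeline loaded without text encoders consumes the precomputed embeddings.
no_text_pipe = StableDiffusionPipeline.from_pretrained(
    repo_id, text_encoder=None, tokenizer=None
).to(device)
image = no_text_pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]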