Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bfa253a
poc encode_prompt() tests
sayakpaul Jan 3, 2025
6ce9128
fix
sayakpaul Jan 3, 2025
d0ac6c2
Merge branch 'main' into tests-encode-prompt
sayakpaul Jan 7, 2025
8a1d84a
Merge branch 'main' into tests-encode-prompt
sayakpaul Jan 8, 2025
0180f9e
Merge branch 'main' into tests-encode-prompt
sayakpaul Jan 9, 2025
f87fa8e
Merge branch 'main' into tests-encode-prompt
sayakpaul Jan 10, 2025
a4a917e
Merge branch 'main' into tests-encode-prompt
sayakpaul Jan 30, 2025
e2f34ad
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 11, 2025
f1e58a6
updates.
sayakpaul Feb 11, 2025
f43dd95
fixes
sayakpaul Feb 11, 2025
9a2ec46
fixes
sayakpaul Feb 11, 2025
1767341
updates
sayakpaul Feb 11, 2025
2ea8313
updates
sayakpaul Feb 11, 2025
af4e8d0
updates
sayakpaul Feb 12, 2025
3e15f7c
revert
sayakpaul Feb 12, 2025
b148bab
updates
sayakpaul Feb 12, 2025
8c004ea
updates
sayakpaul Feb 13, 2025
7614b3c
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 13, 2025
6734a12
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 14, 2025
ffe821c
updates
sayakpaul Feb 15, 2025
4e39d21
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 15, 2025
8200b27
updates
sayakpaul Feb 17, 2025
c31d209
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 17, 2025
464edca
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 18, 2025
a3f19c3
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 19, 2025
76aaab0
remove SDXLOptionalComponentsTesterMixin.
sayakpaul Feb 20, 2025
b1c9666
remove tests that directly leveraged encode_prompt() in some way or t…
sayakpaul Feb 20, 2025
7068529
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 20, 2025
11b5fbd
fix imports.
sayakpaul Feb 20, 2025
6e613f9
remove _save_load
sayakpaul Feb 20, 2025
9b2e58d
fixes
sayakpaul Feb 20, 2025
28e26ea
fixes
sayakpaul Feb 20, 2025
7d7599f
fixes
sayakpaul Feb 20, 2025
93bff6c
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 20, 2025
a33ac2f
fixes
sayakpaul Feb 20, 2025
9d39ab2
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 20, 2025
cfda21b
Merge branch 'main' into tests-encode-prompt
sayakpaul Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

self.tokenizer.padding_side = "right"
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"

# See Section 3.1. of the paper.
max_length = max_sequence_length
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/sana/pipeline_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def encode_prompt(
else:
batch_size = prompt_embeds.shape[0]

self.tokenizer.padding_side = "right"
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"

# See Section 3.1. of the paper.
max_length = max_sequence_length
Expand Down
52 changes: 52 additions & 0 deletions src/diffusers/utils/source_code_parsing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import ast
import importlib
import inspect
import textwrap


class ReturnNameVisitor(ast.NodeVisitor):
    """AST visitor that records a textual name for each value found in ``return`` statements.

    Entries accumulate in ``self.return_names`` in visit order. A returned tuple
    contributes one entry per element; any other returned expression contributes
    a single entry.
    """

    def __init__(self):
        # One string per returned value, appended by `visit_Return`.
        self.return_names = []

    def visit_Return(self, node):
        """Record the name(s) returned by ``node``, then continue into its children."""
        returned = node.value
        # A tuple is flattened one level; anything else is treated as a single value.
        candidates = returned.elts if isinstance(returned, ast.Tuple) else [returned]
        for expr in candidates:
            if isinstance(expr, ast.Name):
                label = expr.id
            else:
                # Non-name expressions are rendered back to source text; if that
                # fails for any reason, fall back to the `str()` form of the node.
                try:
                    label = ast.unparse(expr)
                except Exception:
                    label = str(expr)
            self.return_names.append(label)
        self.generic_visit(node)

    def _determine_parent_module(self, cls):
        """Return ``"pipelines"`` or ``"models"`` according to the diffusers base class of ``cls``.

        Raises:
            NotImplementedError: If ``cls`` subclasses neither ``DiffusionPipeline``
                nor ``ModelMixin``.
        """
        from diffusers import DiffusionPipeline
        from diffusers.models.modeling_utils import ModelMixin

        if issubclass(cls, DiffusionPipeline):
            return "pipelines"
        if issubclass(cls, ModelMixin):
            return "models"
        raise NotImplementedError

    def get_ast_tree(self, cls, attribute_name="encode_prompt"):
        """Parse the source of attribute ``attribute_name`` of ``cls`` into an AST.

        The class is looked up by name on ``diffusers.pipelines`` or
        ``diffusers.models``; the source of its attribute is fetched with
        ``inspect.getsource``, dedented, and handed to ``ast.parse``.
        """
        package = importlib.import_module(f"diffusers.{self._determine_parent_module(cls)}")
        resolved_cls = getattr(package, cls.__name__)
        target = getattr(resolved_cls, attribute_name)
        # Dedent before parsing: the source of an indented definition would
        # otherwise not parse as standalone module-level code.
        return ast.parse(textwrap.dedent(inspect.getsource(target)))
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,14 @@ def test_xformers_attention_forwardGenerator_pass(self):
# Runs the parent class's VAE slicing test with image_count=2.
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_accelerator
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,11 @@ def test_vae_slicing(self, video_count=2):
output_2 = pipe(**inputs)

assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
4 changes: 4 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,7 @@ def test_xformers_attention_forwardGenerator_pass(self):

max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test currently not supported.")
def test_encode_prompt_works_in_isolation(self):
pass
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,11 @@ def test_free_init_with_schedulers(self):

# Runs the parent class's VAE slicing test with image_count=2.
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_video2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,11 @@ def test_free_noise_multi_prompt(self):
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,11 @@ def test_free_noise_multi_prompt(self):
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
5 changes: 5 additions & 0 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,9 +508,14 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported.")
def test_sequential_cpu_offload_forward_pass(self):
pass

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
def test_encode_prompt_works_in_isolation(self):
pass


@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/blipdiffusion/test_blipdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,7 @@ def test_blipdiffusion(self):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
pass
3 changes: 3 additions & 0 deletions tests/pipelines/cogview3/test_cogview3plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def test_attention_slicing_forward_pass(
"Attention slicing should not affect the inference results",
)

# Delegates to the parent class's test, passing atol=1e-3 and rtol=1e-3.
def test_encode_prompt_works_in_isolation(self):
return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)


@slow
@require_torch_accelerator
Expand Down
21 changes: 21 additions & 0 deletions tests/pipelines/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ def test_controlnet_lcm_custom_timesteps(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
Expand Down Expand Up @@ -522,6 +529,13 @@ def test_inference_multiple_prompt_input(self):

assert image.shape == (4, 64, 64, 3)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


class StableDiffusionMultiControlNetOneModelPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
Expand Down Expand Up @@ -707,6 +721,13 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_accelerator
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,7 @@ def test_blipdiffusion_controlnet(self):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
pass
14 changes: 14 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ def test_xformers_attention_forwardGenerator_pass(self):
# Calls the `_test_inference_batch_single_identical` helper with expected_max_diff=2e-3.
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
Expand Down Expand Up @@ -391,6 +398,13 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_accelerator
Expand Down
14 changes: 14 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ def test_xformers_attention_forwardGenerator_pass(self):
# Calls the `_test_inference_batch_single_identical` helper with expected_max_diff=2e-3.
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests):
pipeline_class = StableDiffusionControlNetInpaintPipeline
Expand Down Expand Up @@ -443,6 +450,13 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_accelerator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ def test_save_load_optional_components(self):
# TODO(YiYi) need to fix later
pass

# Skipped unconditionally; the reason string is what unittest reports for the skip.
# (Spelling of "deviates" corrected in the reason.)
@unittest.skip(
    "Test not supported as `encode_prompt` is called two times separately which deviates from about 99% of the pipelines we have."
)
def test_encode_prompt_works_in_isolation(self):
    pass


@slow
@require_torch_accelerator
Expand Down
7 changes: 7 additions & 0 deletions tests/pipelines/controlnet_xs/test_controlnetxs.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,13 @@ def test_to_device(self):
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_accelerator
Expand Down
6 changes: 6 additions & 0 deletions tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."

# Skipped unconditionally; the reason string is what unittest reports for the skip.
# (Spelling of "deviates" corrected in the reason.)
@unittest.skip(
    "Test not supported as `encode_prompt` is called two times separately which deviates from about 99% of the pipelines we have."
)
def test_encode_prompt_works_in_isolation(self):
    pass


@slow
@require_torch_accelerator
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/i2vgen_xl/test_i2vgenxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ def test_num_videos_per_prompt(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported for now.")
def test_encode_prompt_works_in_isolation(self):
pass


@slow
@require_torch_accelerator
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/kolors/test_kolors_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,7 @@ def test_inference_batch_single_identical(self):

# Runs the parent class's float16 inference test with expected_max_diff=7e-2.
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=7e-2)

# Skipped unconditionally; the reason string is what unittest reports for the skip.
# ("unline" corrected to "unlike" in the reason.)
@unittest.skip("Test not supported because kolors img2img doesn't take pooled embeds as inputs unlike kolors t2i.")
def test_encode_prompt_works_in_isolation(self):
    pass
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
output = pipe(**inputs)[0]
assert output.abs().sum() == 0

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
output = pipe(**inputs)[0]
assert output.abs().sum() == 0

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)


@slow
@require_torch_gpu
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/latte/test_latte.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def test_save_load_optional_components(self):
# Calls the parent class's `_test_xformers_attention_forwardGenerator_pass` helper
# with test_mean_pixel_difference=False.
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)

# Unconditionally skipped via unittest.skip (reason in the decorator); the body is a placeholder.
@unittest.skip("Test not supported because `encode_prompt()` has multiple returns.")
def test_encode_prompt_works_in_isolation(self):
pass


@slow
@require_torch_gpu
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/pag/test_pag_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,11 @@ def test_pag_applied_layers(self):
pag_layers = ["motion_modules.42"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)

def test_encode_prompt_works_in_isolation(self):
    # Extra values forwarded to the parent class's implementation of this test.
    device_type = torch.device(torch_device).type
    dummy_inputs = self.get_dummy_inputs(device=torch_device)
    # Without a "guidance_scale" entry the 1.0 fallback applies, so the flag is False.
    use_cfg = dummy_inputs.get("guidance_scale", 1.0) > 1.0
    extra_required_param_value_dict = {
        "device": device_type,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": use_cfg,
    }
    return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
Loading
Loading