Closed
Description
Is your feature request related to a problem? Please describe.
Unlike some other types of pipelines, the Stable Diffusion pipeline doesn't support using an HF Dataset directly as input. For example:
import torch
from datasets import load_dataset
from diffusers import DiffusionPipeline
base = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True, batch_size=8)
dset = load_dataset("AIEnergyScore/image_generation", split="train").select(range(100))
base(dset)
gives the following error:
/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py in check_inputs(self, prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs)
653 )
654 elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
--> 655 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
656 elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
657 raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
ValueError: `prompt` has to be of type `str` or `list` but is <class 'datasets.arrow_dataset.Dataset'>
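The isinstance check shown in the traceback accepts only str or list, so the rejection is not specific to Dataset: a tuple or an iterator of strings fails it too. The sketch below re-implements that one condition in isolation to show which inputs pass. The function name prompt_type_ok is hypothetical and not part of diffusers; it covers only this type check, not the other validations in check_inputs.

```python
def prompt_type_ok(prompt) -> bool:
    """Hypothetical mirror of the prompt type check in check_inputs.

    check_inputs raises ValueError when prompt is not None and is
    neither a str nor a list, so only None, str, and list pass here.
    """
    return prompt is None or isinstance(prompt, (str, list))


# Only str and list are accepted; any other container is rejected.
for candidate in ["a cat", ["a cat", "a dog"], ("a cat", "a dog"), iter(["a cat"])]:
    print(type(candidate).__name__, prompt_type_ok(candidate))
```

In practice this means any Dataset column, tuple, or generator has to be turned into a plain list (or fed in list-sized batches) before it reaches the pipeline.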
Describe the solution you'd like.
Add support for Datasets in the __call__() method of the Stable Diffusion pipeline:
https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__
Describe alternatives you've considered.
I'm currently converting the prompts from the HF dataset to a list(), but this may be an issue for bigger datasets, since every prompt has to be loaded into memory as one Python list at once.
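A minimal sketch of a batched alternative to converting the whole prompt column with list(), written as a wrapper outside the pipeline rather than as a change to diffusers. It assumes only that iterating over a datasets.Dataset yields one dict per row, so at most batch_size prompts are held in a Python list at a time, and each batch is passed to the pipeline as a plain list, which satisfies the type check. The helper names iter_prompt_batches and generate_images, the column name, and the batch size are all hypothetical; check dset.column_names for the actual prompt column.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List, Mapping


def iter_prompt_batches(
    rows: Iterable[Mapping[str, Any]], column: str, batch_size: int
) -> Iterator[List[str]]:
    """Yield lists of at most batch_size prompts read from `column` of each row."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    while True:
        # Pull the next batch_size rows lazily; an empty batch means we are done.
        batch = [row[column] for row in islice(it, batch_size)]
        if not batch:
            return
        yield batch


def generate_images(
    pipe: Callable[..., Any],
    rows: Iterable[Mapping[str, Any]],
    column: str,
    batch_size: int = 8,
    **pipe_kwargs: Any,
) -> Iterator[Any]:
    """Call `pipe` once per batch with a plain list of prompts and yield each image."""
    for prompts in iter_prompt_batches(rows, column, batch_size):
        # check_inputs accepts str or list, so each batch is passed as a list.
        output = pipe(prompt=prompts, **pipe_kwargs)
        yield from output.images


# Hypothetical usage with `base` and `dset` from this issue
# ("text" is an assumed column name; check dset.column_names):
# for image in generate_images(base, dset, column="text", batch_size=8):
#     image.save(...)
```

Because generate_images is a generator, images can be saved or scored as they are produced instead of accumulating all of them in memory.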
Additional context.
Thank you @lhoestq for the support in figuring this out!