11import importlib .util
2- import json
3- import os
2+ import logging
43
54_diffusers = importlib .util .find_spec ("diffusers" ) is not None
65
@@ -11,60 +10,46 @@ def is_diffusers_available():
1110
1211if is_diffusers_available ():
1312 import torch
14- from diffusers import DPMSolverMultistepScheduler , StableDiffusionPipeline
13+ from diffusers import AutoPipelineForText2Image , DPMSolverMultistepScheduler , StableDiffusionPipeline
1514
15+ logger = logging .getLogger (__name__ )
16+ logging .basicConfig (format = "%(asctime)s | %(levelname)s | %(message)s" , level = logging .INFO )
1617
17- def check_supported_pipeline (model_dir ):
18- try :
19- with open (os .path .join (model_dir , "model_index.json" )) as json_file :
20- data = json .load (json_file )
21- if data ["_class_name" ] == "StableDiffusionPipeline" :
22- return True
23- else :
24- return False
25- except Exception :
26- return False
2718
28-
29- class DiffusersPipelineImageToText :
class IEAutoPipelineForText2Image:
    """Text-to-image inference wrapper around ``AutoPipelineForText2Image``.

    Loads the pipeline stored in ``model_dir`` once and exposes it as a
    callable that returns a single generated image for a prompt.
    """

    def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
        """Load the pipeline from ``model_dir`` and move it to ``device``.

        Args:
            model_dir: Location handed to
                ``AutoPipelineForText2Image.from_pretrained``.
            device: Target device. ``"cuda"`` selects ``torch.float16`` and
                ``device_map="auto"``; any other value selects
                ``torch.float32`` and no device map.
        """
        on_cuda = device == "cuda"
        dtype = torch.float16 if on_cuda else torch.float32
        device_map = "auto" if on_cuda else None

        self.pipeline = AutoPipelineForText2Image.from_pretrained(
            model_dir, torch_dtype=dtype, device_map=device_map
        )
        # try to use DPMSolverMultistepScheduler
        if isinstance(self.pipeline, StableDiffusionPipeline):
            try:
                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                    self.pipeline.scheduler.config
                )
            except Exception:
                # Best effort: keep the scheduler the pipeline was loaded with,
                # but leave a trace instead of failing silently.
                logger.warning(
                    "Could not switch to DPMSolverMultistepScheduler; keeping the pipeline's scheduler.",
                    exc_info=True,
                )
        self.pipeline.to(device)

    def __call__(
        self,
        prompt,
        **kwargs,
    ):
        """Generate one image for ``prompt``.

        Args:
            prompt: Passed to the pipeline as its first positional argument.
            **kwargs: Extra generation parameters forwarded unchanged to the
                pipeline (for example ``num_inference_steps``,
                ``guidance_scale``, ``negative_prompt``, ``height``,
                ``width``). ``num_images_per_prompt`` is always forced to 1.

        Returns:
            The first element of the pipeline output's ``images``.
        """
        # TODO: add support for more images (Reason is correct output)
        requested_images = kwargs.pop("num_images_per_prompt", None)
        if requested_images not in (None, 1):
            logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")

        # Call pipeline with parameters; the caller's kwargs must be forwarded,
        # otherwise every generation option is silently ignored.
        out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)

        return out.images[0]
6449
6550
# Maps a task name to the wrapper class that serves it.
DIFFUSERS_TASKS = {
    "text-to-image": IEAutoPipelineForText2Image,
}
6954
7055
0 commit comments