11from diffusers .pipelines .stable_diffusion_3 .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
2- from diffusers .pipelines .flux .pipeline_flux import FluxPipeline
32import torch
43import os
54import logging
@@ -20,7 +19,6 @@ class TextToImageInput(BaseModel):
2019class PresetModels :
2120 SD3 : List [str ] = field (default_factory = lambda : ['stabilityai/stable-diffusion-3-medium' ])
2221 SD3_5 : List [str ] = field (default_factory = lambda : ['stabilityai/stable-diffusion-3.5-large' , 'stabilityai/stable-diffusion-3.5-large-turbo' , 'stabilityai/stable-diffusion-3.5-medium' ])
23- Flux : List [str ] = field (default_factory = lambda : ['black-forest-labs/FLUX.1-dev' , 'black-forest-labs/FLUX.1-schnell' ])
2422
2523class TextToImagePipelineSD3 :
2624 def __init__ (self , model_path : str | None = None ):
@@ -48,37 +46,6 @@ def start(self):
4846 else :
4947 raise Exception ("No CUDA or MPS device available" )
5048
51- class TextToImagePipelineFlux :
52- def __init__ (self , model_path : str | None = None , low_vram : bool = False ):
53- self .model_path = model_path or os .getenv ("MODEL_PATH" )
54- self .pipeline : FluxPipeline | None = None
55- self .device : str | None = None
56- self .low_vram = low_vram
57-
58- def start (self ):
59- if torch .cuda .is_available ():
60- model_path = self .model_path or "black-forest-labs/FLUX.1-schnell"
61- logger .info ("Loading CUDA" )
62- self .device = "cuda"
63- self .pipeline = FluxPipeline .from_pretrained (
64- model_path ,
65- torch_dtype = torch .bfloat16 ,
66- ).to (device = self .device )
67- if self .low_vram :
68- self .pipeline .enable_model_cpu_offload ()
69- else :
70- pass
71- elif torch .backends .mps .is_available ():
72- model_path = self .model_path or "black-forest-labs/FLUX.1-schnell"
73- logger .info ("Loading MPS for Mac M Series" )
74- self .device = "mps"
75- self .pipeline = FluxPipeline .from_pretrained (
76- model_path ,
77- torch_dtype = torch .bfloat16 ,
78- ).to (device = self .device )
79- else :
80- raise Exception ("No CUDA or MPS device available" )
81-
8249class ModelPipelineInitializer :
8350 def __init__ (self , model : str = '' , type_models : str = 't2im' ):
8451 self .model = model
@@ -99,15 +66,11 @@ def initialize_pipeline(self):
9966 self .model_type = "SD3"
10067 elif self .model in preset_models .SD3_5 :
10168 self .model_type = "SD3_5"
102- elif self .model in preset_models .Flux :
103- self .model_type = "Flux"
10469
10570 # Create appropriate pipeline based on model type and type_models
10671 if self .type_models == 't2im' :
10772 if self .model_type in ["SD3" , "SD3_5" ]:
10873 self .pipeline = TextToImagePipelineSD3 (self .model )
109- elif self .model_type == "Flux" :
110- self .pipeline = TextToImagePipelineFlux (self .model )
11174 else :
11275 raise ValueError (f"Model type { self .model_type } not supported for text-to-image" )
11376 elif self .type_models == 't2v' :
0 commit comments