1+ # flake8: noqa
12import os
2- import io
3- import base64
43import threading
54import datetime
5+ import logging
66from typing import List , Optional
77
88import torch
1313from diffusers import DiffusionPipeline
1414
1515
16+ LOG_LEVEL = os .getenv ("LOG_LEVEL" , "INFO" ).upper ()
17+ logging .basicConfig (
18+ level = getattr (logging , LOG_LEVEL , logging .INFO ),
19+ format = "%(asctime)s %(levelname)s %(name)s %(message)s" ,
20+ )
21+ logger = logging .getLogger ("lcm_app" )
22+
1623app = FastAPI (title = "LCM Text2Image" )
1724
1825
@@ -47,6 +54,8 @@ def _create_pipeline() -> DiffusionPipeline:
4754 use_cuda = torch .cuda .is_available ()
4855 torch_dtype = torch .float16 if use_cuda else torch .float32
4956
57+ logger .info ("Initializing diffusion pipeline: model_id=%s, device=%s, dtype=%s, safety=%s" ,
58+ MODEL_ID , "cuda" if use_cuda else "cpu" , str (torch_dtype ), "enabled" if SAFETY_CHECKER != "disabled" else "disabled" )
5059 pipe = DiffusionPipeline .from_pretrained (
5160 MODEL_ID ,
5261 custom_pipeline = LCM_CUSTOM_PIPELINE ,
@@ -58,6 +67,7 @@ def _create_pipeline() -> DiffusionPipeline:
5867
5968 device = "cuda" if use_cuda else "cpu"
6069 pipe = pipe .to (device )
70+ logger .info ("Pipeline ready on device=%s" , device )
6171 return pipe
6272
6373
@@ -87,6 +97,7 @@ def _generate_filename(prompt: str, timestamp: str, index: int) -> str:
8797def _on_startup ():
8898 # Optionally preload model to avoid first-request latency
8999 if PRELOAD_MODEL :
100+ logger .info ("Preloading model on startup" )
90101 _get_pipeline ()
91102
92103
@@ -157,17 +168,17 @@ def healthz():
157168 </div>
158169 <div>
159170 <label>Steps</label>
160- <input type=\" number\" name=\" num_inference_steps\" min=\" 1\" max=\" 20\" value=\" "" + str( DEFAULT_STEPS) + "" " / >
171+ <input type=\" number\" name=\" num_inference_steps\" min=\" 1\" max=\" 20\" value=\" { DEFAULT_STEPS} \ " />
161172 </div>
162173 <div>
163174 <label>Guidance</label>
164- < input type = \" number \" step = \" 0.5 \" name = \" guidance_scale \" value = \" "" + str ( DEFAULT_GUIDANCE ) + "" " />
175+ <input type=\" number\" step=\" 0.5\" name=\" guidance_scale\" value=\" { DEFAULT_GUIDANCE} \ " />
165176 </div>
166177 </div>
167178 <div class=\" row\" >
168179 <div>
169180 <label>LCM Origin Steps</label>
170- <input type=\" number\" name=\" lcm_origin_steps\" min=\" 1\" max=\" 20\" value=\" "" + str( DEFAULT_LCM_ORIGIN_STEPS) + "" " / >
181+ <input type=\" number\" name=\" lcm_origin_steps\" min=\" 1\" max=\" 20\" value=\" { DEFAULT_LCM_ORIGIN_STEPS} \ " />
171182 </div>
172183 </div>
173184 <button id=\" go\" type=\" submit\" >Generate</button>
@@ -181,7 +192,13 @@ def healthz():
181192
182193@app .get ("/" , response_class = HTMLResponse )
183194def index ():
184- return HTMLResponse(INDEX_HTML)
195+ html = (
196+ INDEX_HTML
197+ .replace ("{DEFAULT_STEPS}" , str (DEFAULT_STEPS ))
198+ .replace ("{DEFAULT_GUIDANCE}" , str (DEFAULT_GUIDANCE ))
199+ .replace ("{DEFAULT_LCM_ORIGIN_STEPS}" , str (DEFAULT_LCM_ORIGIN_STEPS ))
200+ )
201+ return HTMLResponse (html )
185202
186203
187204@app .post ("/api/generate" )
@@ -199,16 +216,20 @@ def api_generate(
199216 guidance = guidance_scale or DEFAULT_GUIDANCE
200217 origin_steps = lcm_origin_steps or DEFAULT_LCM_ORIGIN_STEPS
201218
219+ logger .info ("Generation request: prompt=%r, num_images=%s, steps=%s, guidance=%s, origin_steps=%s" ,
220+ prompt [:80 ], num_images , steps , guidance , origin_steps )
202221 try :
203222 pipe = _get_pipeline ()
204223 images = pipe (
205224 prompt = prompt ,
206225 num_inference_steps = steps ,
207226 guidance_scale = guidance ,
208227 lcm_origin_steps = origin_steps ,
228+ num_images_per_prompt = max (1 , num_images ),
209229 output_type = "pil" ,
210230 ).images
211231 except Exception as e :
232+ logger .exception ("Generation failed: %s" , e )
212233 raise HTTPException (status_code = 500 , detail = f"Generation failed: { e } " )
213234
214235 ts = datetime .datetime .now ().strftime ("%m-%d-%H-%M-%S" )
@@ -221,13 +242,14 @@ def api_generate(
221242 try :
222243 _save_image (image , filepath , metadata )
223244 except Exception as e :
245+ logger .exception ("Failed to save image: %s" , e )
224246 raise HTTPException (status_code = 500 , detail = f"Failed to save image: { e } " )
225247 files .append ({
226248 "name" : filename ,
227249 "path" : filepath ,
228250 "url" : f"/outputs/{ filename } " ,
229251 })
230-
252+ logger . info ( "Generated %d images" , len ( files ))
231253 return JSONResponse ({"files" : files })
232254
233255
0 commit comments