1
- from modules .api .models import StableDiffusionTxt2ImgProcessingAPI , StableDiffusionImg2ImgProcessingAPI
1
+ from modules .api .models import StableDiffusionTxt2ImgProcessingAPI , StableDiffusionImg2ImgProcessingAPI , InterrogateAPI
2
2
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
3
3
from modules .sd_samplers import all_samplers
4
4
from modules .extras import run_pnginfo
@@ -25,6 +25,11 @@ class ImageToImageResponse(BaseModel):
25
25
parameters : Json
26
26
info : Json
27
27
28
+ class InterrogateResponse (BaseModel ):
29
+ caption : str = Field (default = None , title = "Caption" , description = "The generated caption for the image." )
30
+ parameters : Json
31
+ info : Json
32
+
28
33
29
34
class Api :
30
35
def __init__ (self , app , queue_lock ):
@@ -33,6 +38,7 @@ def __init__(self, app, queue_lock):
33
38
self .queue_lock = queue_lock
34
39
self .app .add_api_route ("/sdapi/v1/txt2img" , self .text2imgapi , methods = ["POST" ])
35
40
self .app .add_api_route ("/sdapi/v1/img2img" , self .img2imgapi , methods = ["POST" ])
41
+ self .app .add_api_route ("/sdapi/v1/interrogate" , self .interrogateapi , methods = ["POST" ])
36
42
37
43
def __base64_to_image (self , base64_string ):
38
44
# if has a comma, deal with prefix
@@ -118,6 +124,23 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
118
124
119
125
return ImageToImageResponse (images = b64images , parameters = json .dumps (vars (img2imgreq )), info = processed .js ())
120
126
127
+ def interrogateapi (self , interrogatereq : InterrogateAPI ):
128
+ image_b64 = interrogatereq .image
129
+ if image_b64 is None :
130
+ raise HTTPException (status_code = 404 , detail = "Image not found" )
131
+
132
+ populate = interrogatereq .copy (update = { # Override __init__ params
133
+ }
134
+ )
135
+
136
+ img = self .__base64_to_image (image_b64 )
137
+
138
+ # Override object param
139
+ with self .queue_lock :
140
+ processed = shared .interrogator .interrogate (img )
141
+
142
+ return InterrogateResponse (caption = processed , parameters = json .dumps (vars (interrogatereq )), info = None )
143
+
121
144
def extrasapi (self ):
122
145
raise NotImplementedError
123
146
0 commit comments