Skip to content

Commit 5302e2c

Browse files
Merge pull request #3810 from royshil/roy.add_simple_interrogate_api
Add a barebones CLIP interrogate API endpoint
2 parents 6e4de5b + 07d1bd4 commit 5302e2c

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

modules/api/api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
6363
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
6464
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
6565
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
66+
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
6667
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
6768
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
6869
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
@@ -214,6 +215,19 @@ def progressapi(self, req: ProgressRequest = Depends()):
214215

215216
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
216217

218+
def interrogateapi(self, interrogatereq: InterrogateRequest):
219+
image_b64 = interrogatereq.image
220+
if image_b64 is None:
221+
raise HTTPException(status_code=404, detail="Image not found")
222+
223+
img = self.__base64_to_image(image_b64)
224+
225+
# Override object param
226+
with self.queue_lock:
227+
processed = shared.interrogator.interrogate(img)
228+
229+
return InterrogateResponse(caption=processed)
230+
217231
def interruptapi(self):
218232
shared.state.interrupt()
219233

modules/api/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def merge_class_params(class_):
6565

6666
self._model_name = model_name
6767
self._class_data = merge_class_params(class_instance)
68+
6869
self._model_def = [
6970
ModelDef(
7071
field=underscore(k),
@@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
167168
state: dict = Field(title="State", description="The current state snapshot")
168169
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
169170

171+
class InterrogateRequest(BaseModel):
172+
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
173+
174+
class InterrogateResponse(BaseModel):
175+
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
176+
170177
fields = {}
171178
for key, value in opts.data.items():
172179
metadata = opts.data_labels.get(key)
@@ -231,3 +238,4 @@ class ArtistItem(BaseModel):
231238
name: str = Field(title="Name")
232239
score: float = Field(title="Score")
233240
category: str = Field(title="Category")
241+

0 commit comments

Comments
 (0)