Skip to content

Commit f407aa3

Browse files
committed
emulated oai image generation
1 parent cdda9d1 commit f407aa3

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

koboldcpp.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,18 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16391639
ret = handle.sd_load_model(inputs)
16401640
return ret
16411641

1642+
def sd_oai_tranform_params(genparams):
1643+
size = genparams.get('size', "512x512")
1644+
if size and size!="":
1645+
pattern = r'^\D*(\d+)x(\d+)$'
1646+
match = re.fullmatch(pattern, size)
1647+
if match:
1648+
width = int(match.group(1))
1649+
height = int(match.group(2))
1650+
genparams["width"] = width
1651+
genparams["height"] = height
1652+
return genparams
1653+
16421654
def sd_comfyui_tranform_params(genparams):
16431655
promptobj = genparams.get('prompt', None)
16441656
if promptobj and isinstance(promptobj, dict):
@@ -3700,6 +3712,7 @@ def do_POST(self):
37003712
api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat
37013713
is_imggen = False
37023714
is_comfyui_imggen = False
3715+
is_oai_imggen = False
37033716
is_transcribe = False
37043717
is_tts = False
37053718
is_embeddings = False
@@ -3785,10 +3798,12 @@ def do_POST(self):
37853798
api_format = 6
37863799
elif self.path.endswith('/api/chat'): #ollama
37873800
api_format = 7
3788-
elif self.path=="/prompt" or self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
3801+
elif self.path=="/prompt" or self.path.endswith('/v1/images/generations') or self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
37893802
is_imggen = True
37903803
if self.path=="/prompt":
37913804
is_comfyui_imggen = True
3805+
elif self.path.endswith('/v1/images/generations'):
3806+
is_oai_imggen = True
37923807
elif self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
37933808
is_transcribe = True
37943809
elif self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech') or self.path.endswith('/tts_to_audio'):
@@ -3898,6 +3913,8 @@ def do_POST(self):
38983913
if is_comfyui_imggen:
38993914
lastgeneratedcomfyimg = b''
39003915
genparams = sd_comfyui_tranform_params(genparams)
3916+
elif is_oai_imggen:
3917+
genparams = sd_oai_tranform_params(genparams)
39013918
gen = sd_generate(genparams)
39023919
genresp = None
39033920
if is_comfyui_imggen:
@@ -3906,6 +3923,8 @@ def do_POST(self):
39063923
else:
39073924
lastgeneratedcomfyimg = b''
39083925
genresp = (json.dumps({"prompt_id": "12345678-0000-0000-0000-000000000001","number": 0,"node_errors":{}}).encode())
3926+
elif is_oai_imggen:
3927+
genresp = (json.dumps({"created":int(time.time()),"data":[{"b64_json":gen}],"background":"opaque","output_format":"png","size":"1024x1024","quality":"medium"}).encode())
39093928
else:
39103929
genresp = (json.dumps({"images":[gen],"parameters":{},"info":""}).encode())
39113930
self.send_response(200)

0 commit comments

Comments
 (0)