3
3
import time
4
4
import uvicorn
5
5
from threading import Lock
6
+ from io import BytesIO
6
7
from gradio .processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
7
8
from fastapi import APIRouter , Depends , FastAPI , HTTPException
8
9
from fastapi .security import HTTPBasic , HTTPBasicCredentials
13
14
from modules .api .models import *
14
15
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
15
16
from modules .extras import run_extras , run_pnginfo
16
- from PIL import PngImagePlugin
17
+ from PIL import PngImagePlugin , Image
17
18
from modules .sd_models import checkpoints_list
18
19
from modules .realesrgan_model import get_realesrgan_models
19
20
from typing import List
@@ -133,7 +134,10 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
133
134
134
135
mask = img2imgreq .mask
135
136
if mask :
136
- mask = decode_base64_to_image (mask )
137
+ if mask .startswith ("data:image/" ):
138
+ mask = decode_base64_to_image (mask )
139
+ else :
140
+ mask = Image .open (BytesIO (base64 .b64decode (mask )))
137
141
138
142
populate = img2imgreq .copy (update = { # Override __init__ params
139
143
"sd_model" : shared .sd_model ,
@@ -147,7 +151,10 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
147
151
148
152
imgs = []
149
153
for img in init_images :
150
- img = decode_base64_to_image (img )
154
+ if img .startswith ("data:image/" ):
155
+ img = decode_base64_to_image (img )
156
+ else :
157
+ img = Image .open (BytesIO (base64 .b64decode (img )))
151
158
imgs = [img ] * p .batch_size
152
159
153
160
p .init_images = imgs
0 commit comments