Skip to content

Commit 336c341

Browse files
committed
Merge branch 'master' into api-authorization
2 parents 8f2ff86 + 84a6f21 commit 336c341

28 files changed

+187
-144
lines changed

modules/api/api.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from secrets import compare_digest
1010

1111
import modules.shared as shared
12+
from modules import sd_samplers
1213
from modules.api.models import *
1314
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
14-
from modules.sd_samplers import all_samplers
1515
from modules.extras import run_extras, run_pnginfo
1616
from PIL import PngImagePlugin
1717
from modules.sd_models import checkpoints_list
@@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
2828
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
2929

3030

31-
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
31+
def validate_sampler_name(name):
32+
config = sd_samplers.all_samplers_map.get(name, None)
33+
if config is None:
34+
raise HTTPException(status_code=404, detail="Sampler not found")
3235

36+
return name
3337

3438
def setUpscalers(req: dict):
3539
reqDict = vars(req)
@@ -77,6 +81,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
7781
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
7882
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
7983
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
84+
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
8085
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
8186
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
8287
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@@ -103,14 +108,9 @@ def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
103108
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
104109

105110
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
106-
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
107-
108-
if sampler_index is None:
109-
raise HTTPException(status_code=404, detail="Sampler not found")
110-
111111
populate = txt2imgreq.copy(update={ # Override __init__ params
112112
"sd_model": shared.sd_model,
113-
"sampler_index": sampler_index[0],
113+
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
114114
"do_not_save_samples": True,
115115
"do_not_save_grid": True
116116
}
@@ -130,12 +130,6 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
130130
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
131131

132132
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
133-
sampler_index = sampler_to_index(img2imgreq.sampler_index)
134-
135-
if sampler_index is None:
136-
raise HTTPException(status_code=404, detail="Sampler not found")
137-
138-
139133
init_images = img2imgreq.init_images
140134
if init_images is None:
141135
raise HTTPException(status_code=404, detail="Init image not found")
@@ -144,10 +138,9 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
144138
if mask:
145139
mask = decode_base64_to_image(mask)
146140

147-
148141
populate = img2imgreq.copy(update={ # Override __init__ params
149142
"sd_model": shared.sd_model,
150-
"sampler_index": sampler_index[0],
143+
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
151144
"do_not_save_samples": True,
152145
"do_not_save_grid": True,
153146
"mask": mask
@@ -266,6 +259,9 @@ def interruptapi(self):
266259

267260
return {}
268261

262+
def skip(self):
263+
shared.state.skip()
264+
269265
def get_config(self):
270266
options = {}
271267
for key in shared.opts.data.keys():
@@ -277,14 +273,10 @@ def get_config(self):
277273

278274
return options
279275

280-
def set_config(self, req: OptionsModel):
281-
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
282-
# overwrite all options with default values.
283-
raise RuntimeError('Setting options via API is not supported')
284-
285-
reqDict = vars(req)
286-
for o in reqDict:
287-
setattr(shared.opts, o, reqDict[o])
276+
def set_config(self, req: Dict[str, Any]):
277+
278+
for o in req:
279+
setattr(shared.opts, o, req[o])
288280

289281
shared.opts.save(shared.config_filename)
290282
return
@@ -293,7 +285,7 @@ def get_cmd_flags(self):
293285
return vars(shared.cmd_opts)
294286

295287
def get_samplers(self):
296-
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
288+
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
297289

298290
def get_upscalers(self):
299291
upscalers = []

modules/api/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ class InterrogateResponse(BaseModel):
176176
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
177177

178178
fields = {}
179-
for key, value in opts.data.items():
180-
metadata = opts.data_labels.get(key)
181-
optType = opts.typemap.get(type(value), type(value))
179+
for key, metadata in opts.data_labels.items():
180+
value = opts.data.get(key)
181+
optType = opts.typemap.get(type(metadata.default), type(value))
182182

183183
if (metadata is not None):
184184
fields.update({key: (Optional[optType], Field(

modules/extensions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ def check_updates(self):
6565
self.can_update = False
6666
self.status = "latest"
6767

68-
def pull(self):
68+
def fetch_and_reset_hard(self):
6969
repo = git.Repo(self.path)
70-
repo.remotes.origin.pull()
70+
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
71+
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
72+
repo.git.fetch('--all')
73+
repo.git.reset('--hard', 'origin')
7174

7275

7376
def list_extensions():

modules/generation_parameters_copypaste.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict):
7373
'sd_hypernetwork': 'Hypernet',
7474
'sd_hypernetwork_strength': 'Hypernet strength',
7575
'CLIP_stop_at_last_layers': 'Clip skip',
76+
'inpainting_mask_weight': 'Conditional mask weight',
7677
'sd_model_checkpoint': 'Model hash',
7778
}
7879
settings_paste_fields = [

modules/hypernetworks/hypernetwork.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tqdm
1313
from einops import rearrange, repeat
1414
from ldm.util import default
15-
from modules import devices, processing, sd_models, shared
15+
from modules import devices, processing, sd_models, shared, sd_samplers
1616
from modules.textual_inversion import textual_inversion
1717
from modules.textual_inversion.learn_schedule import LearnRateScheduler
1818
from torch import einsum
@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
535535
p.prompt = preview_prompt
536536
p.negative_prompt = preview_negative_prompt
537537
p.steps = preview_steps
538-
p.sampler_index = preview_sampler_index
538+
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
539539
p.cfg_scale = preview_cfg_scale
540540
p.seed = preview_seed
541541
p.width = preview_width

modules/images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class FilenameGenerator:
303303
'width': lambda self: self.image.width,
304304
'height': lambda self: self.image.height,
305305
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
306-
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
306+
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
307307
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
308308
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
309309
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]

modules/img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from PIL import Image, ImageOps, ImageChops
88

9-
from modules import devices
9+
from modules import devices, sd_samplers
1010
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
1111
from modules.shared import opts, state
1212
import modules.shared as shared
@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
9999
seed_resize_from_h=seed_resize_from_h,
100100
seed_resize_from_w=seed_resize_from_w,
101101
seed_enable_extras=seed_enable_extras,
102-
sampler_index=sampler_index,
102+
sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
103103
batch_size=batch_size,
104104
n_iter=n_iter,
105105
steps=steps,

0 commit comments

Comments
 (0)