Skip to content

Commit 994136b

Browse files
Merge pull request #4294 from evshiron/feat/allow-origins
add --cors-allow-origins cmd opt
2 parents c9b2eef + 37ba007 commit 994136b

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

modules/shared.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
8787
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
8888
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
89+
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
8990
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
9091
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
9192
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
@@ -150,9 +151,9 @@ def interrupt(self):
150151
self.interrupted = True
151152

152153
def nextjob(self):
153-
if opts.show_progress_every_n_steps == -1:
154+
if opts.show_progress_every_n_steps == -1:
154155
self.do_set_current_image()
155-
156+
156157
self.job_no += 1
157158
self.sampling_step = 0
158159
self.current_image_sampling_step = 0
@@ -201,7 +202,7 @@ def do_set_current_image(self):
201202
return
202203
if self.current_latent is None:
203204
return
204-
205+
205206
if opts.show_progress_grid:
206207
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
207208
else:

webui.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import signal
66
import threading
77
from fastapi import FastAPI
8+
from fastapi.middleware.cors import CORSMiddleware
89
from fastapi.middleware.gzip import GZipMiddleware
910

1011
from modules.paths import script_path
@@ -107,6 +108,11 @@ def sigint_handler(sig, frame):
107108
signal.signal(signal.SIGINT, sigint_handler)
108109

109110

111+
def setup_cors(app):
112+
if cmd_opts.cors_allow_origins:
113+
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
114+
115+
110116
def create_api(app):
111117
from modules.api.api import Api
112118
api = Api(app, queue_lock)
@@ -128,6 +134,7 @@ def api_only():
128134
initialize()
129135

130136
app = FastAPI()
137+
setup_cors(app)
131138
app.add_middleware(GZipMiddleware, minimum_size=1000)
132139
api = create_api(app)
133140

@@ -163,6 +170,8 @@ def webui():
163170
# runnnig its code. We disable this here. Suggested by RyotaK.
164171
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
165172

173+
setup_cors(app)
174+
166175
app.add_middleware(GZipMiddleware, minimum_size=1000)
167176

168177
if launch_api:

0 commit comments

Comments
 (0)