Skip to content

Commit b8435e6

Browse files
committed
add --cors-allow-origins cmd opt
1 parent 89722fb commit b8435e6

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

modules/shared.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
8787
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
8888
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
89+
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
8990

9091
cmd_opts = parser.parse_args()
9192
restricted_opts = {
@@ -147,9 +148,9 @@ def interrupt(self):
147148
self.interrupted = True
148149

149150
def nextjob(self):
150-
if opts.show_progress_every_n_steps == -1:
151+
if opts.show_progress_every_n_steps == -1:
151152
self.do_set_current_image()
152-
153+
153154
self.job_no += 1
154155
self.sampling_step = 0
155156
self.current_image_sampling_step = 0
@@ -198,7 +199,7 @@ def do_set_current_image(self):
198199
return
199200
if self.current_latent is None:
200201
return
201-
202+
202203
if opts.show_progress_grid:
203204
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
204205
else:

webui.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import signal
66
import threading
77
from fastapi import FastAPI
8+
from fastapi.middleware.cors import CORSMiddleware
89
from fastapi.middleware.gzip import GZipMiddleware
910

1011
from modules.paths import script_path
@@ -93,6 +94,11 @@ def sigint_handler(sig, frame):
9394
signal.signal(signal.SIGINT, sigint_handler)
9495

9596

97+
def setup_cors(app):
98+
if cmd_opts.cors_allow_origins:
99+
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
100+
101+
96102
def create_api(app):
97103
from modules.api.api import Api
98104
api = Api(app, queue_lock)
@@ -114,6 +120,7 @@ def api_only():
114120
initialize()
115121

116122
app = FastAPI()
123+
setup_cors(app)
117124
app.add_middleware(GZipMiddleware, minimum_size=1000)
118125
api = create_api(app)
119126

@@ -147,6 +154,8 @@ def webui():
147154
# runnnig its code. We disable this here. Suggested by RyotaK.
148155
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
149156

157+
setup_cors(app)
158+
150159
app.add_middleware(GZipMiddleware, minimum_size=1000)
151160

152161
if launch_api:

0 commit comments

Comments
 (0)