Skip to content

Commit 1497842

Browse files
committed
rework #3722 to not introduce duplicate code
1 parent 060ee5d commit 1497842

File tree

3 files changed

+35
-49
lines changed

3 files changed

+35
-49
lines changed

modules/api/api.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,17 @@
 from modules.sd_samplers import all_samplers
 from modules.extras import run_extras, run_pnginfo
 
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
-    devices.torch_gc()
-
-    shared.state.sampling_step = 0
-    shared.state.job_count = -1
-    shared.state.job_no = 0
-    shared.state.job_timestamp = shared.state.get_job_timestamp()
-    shared.state.current_latent = None
-    shared.state.current_image = None
-    shared.state.current_image_sampling_step = 0
-    shared.state.skipped = False
-    shared.state.interrupted = False
-    shared.state.textinfo = None
-    shared.state.time_start = time.time()
-
-def after_gpu_call():
-    shared.state.job = ""
-    shared.state.job_count = 0
-
-    devices.torch_gc()
 
 def upscaler_to_index(name: str):
     try:
         return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
     except:
         raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
 
+
 sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
 
+
 def setUpscalers(req: dict):
     reqDict = vars(req)
     reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
     reqDict.pop('upscaler_2')
     return reqDict
 
+
 class Api:
     def __init__(self, app, queue_lock):
         self.router = APIRouter()
@@ -78,10 +56,13 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
         )
         p = StableDiffusionProcessingTxt2Img(**vars(populate))
         # Override object param
-        before_gpu_call()
+
+        shared.state.begin()
+
         with self.queue_lock:
             processed = process_images(p)
-        after_gpu_call()
+
+        shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images))
 
@@ -119,11 +100,13 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
         imgs = [img] * p.batch_size
 
         p.init_images = imgs
-        # Override object param
-        before_gpu_call()
+
+        shared.state.begin()
+
         with self.queue_lock:
             processed = process_images(p)
-        after_gpu_call()
+
+        shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images))
 
modules/shared.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,6 @@ def nextjob(self):
         self.sampling_step = 0
         self.current_image_sampling_step = 0
 
-    def get_job_timestamp(self):
-        return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
-
     def dict(self):
         obj = {
             "skipped": self.skipped,
@@ -160,6 +157,25 @@ def dict(self):
 
         return obj
 
+    def begin(self):
+        self.sampling_step = 0
+        self.job_count = -1
+        self.job_no = 0
+        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+        self.current_latent = None
+        self.current_image = None
+        self.current_image_sampling_step = 0
+        self.skipped = False
+        self.interrupted = False
+        self.textinfo = None
+
+        devices.torch_gc()
+
+    def end(self):
+        self.job = ""
+        self.job_count = 0
+
+        devices.torch_gc()
 
 state = State()
 
webui.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,13 @@ def f(*args, **kwargs):
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
     def f(*args, **kwargs):
-        devices.torch_gc()
-
-        shared.state.sampling_step = 0
-        shared.state.job_count = -1
-        shared.state.job_no = 0
-        shared.state.job_timestamp = shared.state.get_job_timestamp()
-        shared.state.current_latent = None
-        shared.state.current_image = None
-        shared.state.current_image_sampling_step = 0
-        shared.state.skipped = False
-        shared.state.interrupted = False
-        shared.state.textinfo = None
+
+        shared.state.begin()
 
         with queue_lock:
             res = func(*args, **kwargs)
 
-        shared.state.job = ""
-        shared.state.job_count = 0
-
-        devices.torch_gc()
+        shared.state.end()
 
         return res
 
0 commit comments

Comments
 (0)