Skip to content

Commit cba1d79

Browse files
committed
Add multithreaded capability to WASApplyLUT
1 parent 027832d commit cba1d79

File tree

3 files changed

+80
-7
lines changed

3 files changed

+80
-7
lines changed

__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import traceback
55

6-
# Rich (optional)
6+
77
try:
88
from rich.console import Console
99
from rich.table import Table

nodes/WASLUT.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,41 @@
77

88
from pathlib import Path
99
from PIL import Image, ImageDraw, ImageFont
10+
from concurrent.futures import ThreadPoolExecutor, as_completed
11+
import threading
12+
13+
try:
14+
from comfy.utils import ProgressBar as ComfyProgressBar
15+
except Exception:
16+
from tqdm import tqdm
17+
18+
class ComfyProgressBar:
19+
def __init__(self, total=None):
20+
try:
21+
self._bar = tqdm(total=total, desc="Applying LUT", unit="frame")
22+
except Exception:
23+
self._bar = None
24+
25+
def update(self, n: int = 1):
26+
try:
27+
if self._bar is not None:
28+
self._bar.update(n)
29+
except Exception:
30+
pass
31+
32+
def close(self):
33+
try:
34+
if self._bar is not None:
35+
self._bar.close()
36+
except Exception:
37+
pass
38+
39+
def __del__(self):
40+
try:
41+
if self._bar is not None:
42+
self._bar.close()
43+
except Exception:
44+
pass
1045

1146
try:
1247
from folder_paths import folder_names_and_paths
@@ -962,6 +997,8 @@ def INPUT_TYPES(cls):
962997
"image": ("IMAGE",),
963998
"lut": ("LUT",),
964999
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1000+
"use_threads": ("BOOLEAN", {"default": False}),
1001+
"threads": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
9651002
}
9661003
}
9671004

@@ -971,13 +1008,49 @@ def INPUT_TYPES(cls):
9711008
FUNCTION = "run"
9721009
CATEGORY = "WAS/Color/LUT"
9731010

974-
def run(self, image, lut, strength):
1011+
def run(self, image, lut, strength, use_threads=False, threads=0):
9751012
size = lut.size() if lut.size() > 1 else 33
9761013
lut3 = WASLUT.convert_to_3d(lut, size)
977-
y = WASLUT.apply_lut_3d(image, lut3.table_3d, lut3.domain_min, lut3.domain_max).clamp(0, 1)
978-
if strength < 1.0:
979-
y = image * (1.0 - strength) + y * strength
980-
return (y.clamp(0, 1),)
1014+
1015+
b = int(image.shape[0])
1016+
pb = ComfyProgressBar(total=b)
1017+
1018+
if (not use_threads) or b <= 1:
1019+
y = WASLUT.apply_lut_3d(image, lut3.table_3d, lut3.domain_min, lut3.domain_max).clamp(0, 1)
1020+
if strength < 1.0:
1021+
y = image * (1.0 - strength) + y * strength
1022+
pb.update(b)
1023+
return (y.clamp(0, 1),)
1024+
1025+
try:
1026+
cpu_cnt = os.cpu_count() or 1
1027+
except Exception:
1028+
cpu_cnt = 1
1029+
max_workers = int(threads) if int(threads) > 0 else min(cpu_cnt, b)
1030+
max_workers = max(1, min(max_workers, 64))
1031+
1032+
frames = [(i, image[i:i+1]) for i in range(b)]
1033+
results = [None] * b
1034+
lock = threading.Lock()
1035+
1036+
def work(item):
1037+
i, frame = item
1038+
yi = WASLUT.apply_lut_3d(frame, lut3.table_3d, lut3.domain_min, lut3.domain_max).clamp(0, 1)
1039+
if strength < 1.0:
1040+
yi = frame * (1.0 - strength) + yi * strength
1041+
yi = yi.clamp(0, 1)
1042+
with lock:
1043+
pb.update(1)
1044+
return i, yi
1045+
1046+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
1047+
futs = [ex.submit(work, item) for item in frames]
1048+
for f in as_completed(futs):
1049+
i, yi = f.result()
1050+
results[i] = yi
1051+
1052+
y = torch.cat(results, dim=0)
1053+
return (y,)
9811054

9821055
# Save LUT
9831056

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "was-extras"
3-
version = "1.0.0"
3+
version = "1.0.1"
44
description = "A collection of experimental WAS nodes and utilities for ComfyUI."
55
readme = "README.md"
66
requires-python = ">=3.10"

0 commit comments

Comments
 (0)