77
88from pathlib import Path
99from PIL import Image , ImageDraw , ImageFont
10+ from concurrent .futures import ThreadPoolExecutor , as_completed
11+ import threading
12+
13+ try :
14+ from comfy .utils import ProgressBar as ComfyProgressBar
15+ except Exception :
16+ from tqdm import tqdm
17+
18+ class ComfyProgressBar :
19+ def __init__ (self , total = None ):
20+ try :
21+ self ._bar = tqdm (total = total , desc = "Applying LUT" , unit = "frame" )
22+ except Exception :
23+ self ._bar = None
24+
25+ def update (self , n : int = 1 ):
26+ try :
27+ if self ._bar is not None :
28+ self ._bar .update (n )
29+ except Exception :
30+ pass
31+
32+ def close (self ):
33+ try :
34+ if self ._bar is not None :
35+ self ._bar .close ()
36+ except Exception :
37+ pass
38+
39+ def __del__ (self ):
40+ try :
41+ if self ._bar is not None :
42+ self ._bar .close ()
43+ except Exception :
44+ pass
1045
1146try :
1247 from folder_paths import folder_names_and_paths
@@ -962,6 +997,8 @@ def INPUT_TYPES(cls):
962997 "image" : ("IMAGE" ,),
963998 "lut" : ("LUT" ,),
964999 "strength" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
1000+ "use_threads" : ("BOOLEAN" , {"default" : False }),
1001+ "threads" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 64 , "step" : 1 }),
9651002 }
9661003 }
9671004
@@ -971,13 +1008,49 @@ def INPUT_TYPES(cls):
9711008 FUNCTION = "run"
9721009 CATEGORY = "WAS/Color/LUT"
9731010
974- def run (self , image , lut , strength ):
1011+ def run (self , image , lut , strength , use_threads = False , threads = 0 ):
9751012 size = lut .size () if lut .size () > 1 else 33
9761013 lut3 = WASLUT .convert_to_3d (lut , size )
977- y = WASLUT .apply_lut_3d (image , lut3 .table_3d , lut3 .domain_min , lut3 .domain_max ).clamp (0 , 1 )
978- if strength < 1.0 :
979- y = image * (1.0 - strength ) + y * strength
980- return (y .clamp (0 , 1 ),)
1014+
1015+ b = int (image .shape [0 ])
1016+ pb = ComfyProgressBar (total = b )
1017+
1018+ if (not use_threads ) or b <= 1 :
1019+ y = WASLUT .apply_lut_3d (image , lut3 .table_3d , lut3 .domain_min , lut3 .domain_max ).clamp (0 , 1 )
1020+ if strength < 1.0 :
1021+ y = image * (1.0 - strength ) + y * strength
1022+ pb .update (b )
1023+ return (y .clamp (0 , 1 ),)
1024+
1025+ try :
1026+ cpu_cnt = os .cpu_count () or 1
1027+ except Exception :
1028+ cpu_cnt = 1
1029+ max_workers = int (threads ) if int (threads ) > 0 else min (cpu_cnt , b )
1030+ max_workers = max (1 , min (max_workers , 64 ))
1031+
1032+ frames = [(i , image [i :i + 1 ]) for i in range (b )]
1033+ results = [None ] * b
1034+ lock = threading .Lock ()
1035+
1036+ def work (item ):
1037+ i , frame = item
1038+ yi = WASLUT .apply_lut_3d (frame , lut3 .table_3d , lut3 .domain_min , lut3 .domain_max ).clamp (0 , 1 )
1039+ if strength < 1.0 :
1040+ yi = frame * (1.0 - strength ) + yi * strength
1041+ yi = yi .clamp (0 , 1 )
1042+ with lock :
1043+ pb .update (1 )
1044+ return i , yi
1045+
1046+ with ThreadPoolExecutor (max_workers = max_workers ) as ex :
1047+ futs = [ex .submit (work , item ) for item in frames ]
1048+ for f in as_completed (futs ):
1049+ i , yi = f .result ()
1050+ results [i ] = yi
1051+
1052+ y = torch .cat (results , dim = 0 )
1053+ return (y ,)
9811054
9821055# Save LUT
9831056
0 commit comments