 from tqdm import tqdm

 import folder_paths
-from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner
 import comfy_extras
-from comfy_extras.nodes_upscale_model import UpscaleModelLoader, ImageUpscaleWithModel
-from comfy.model_management import soft_empty_cache, free_memory, get_torch_device, current_loaded_models, load_model_gpu
+from spandrel import ModelLoader, ImageModelDescriptor
+from comfy import model_management

 from nodes import LoraLoader, ConditioningAverage, common_ksampler, ImageScale, ImageScaleBy, VAEEncode, VAEDecode
 import comfy.utils
-from comfy_extras.chainner_models import model_loading
-from comfy_extras.nodes_custom_sampler import Noise_EmptyNoise, Noise_RandomNoise
-from comfy import model_management, model_base
+
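+# Local copies of the node logic previously imported from comfy_extras
+# (replacing the imports removed above).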
+# region COMFY_EXTRAS
+def CLIPTextEncodeSDXLRefiner_encode(clip, ascore, width, height, text):
+    tokens = clip.tokenize(text)
+    return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}),)
+
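+# SDXL conditions on two text encoders, CLIP-G ("g") and CLIP-L ("l"); their
+# token lists must match in length, so the shorter one is padded with
+# empty-prompt tokens before encoding.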
+def CLIPTextEncodeSDXL_encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
+    tokens = clip.tokenize(text_g)
+    tokens["l"] = clip.tokenize(text_l)["l"]
+    if len(tokens["l"]) != len(tokens["g"]):
+        empty = clip.tokenize("")
+        while len(tokens["l"]) < len(tokens["g"]):
+            tokens["l"] += empty["l"]
+        while len(tokens["l"]) > len(tokens["g"]):
+            tokens["g"] += empty["g"]
+    return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}),)
+
+class UpscaleModelLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),),
+                             }}
+    RETURN_TYPES = ("UPSCALE_MODEL",)
+    FUNCTION = "load_model"

+    CATEGORY = "loaders"
+
+    def load_model(self, model_name):
+        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
+        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
+            # strip the "module." prefix left behind by DataParallel-style training wrappers
+            sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
+        out = ModelLoader().load_from_state_dict(sd).eval()
+
+        if not isinstance(out, ImageModelDescriptor):
+            raise Exception("Upscale model must be a single-image model.")
+
+        return (out,)
+
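+# Tiled upscaling: start with 512px tiles and halve the tile size on each
+# out-of-memory error, giving up once tiles drop below 128px.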
+class ImageUpscaleWithModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
+                             "image": ("IMAGE",),
+                             }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+
+    CATEGORY = "image/upscaling"
+
+    def upscale(self, upscale_model, image):
+        device = model_management.get_torch_device()
+
+        memory_required = model_management.module_size(upscale_model.model)
+        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
+        memory_required += image.nelement() * image.element_size()
+        model_management.free_memory(memory_required, device)
+
+        upscale_model.to(device)
+        in_img = image.movedim(-1, -3).to(device)
+
+        tile = 512
+        overlap = 32
+
+        oom = True
+        while oom:
+            try:
+                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+                pbar = comfy.utils.ProgressBar(steps)
+                s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                oom = False
+            except model_management.OOM_EXCEPTION as e:
+                tile //= 2
+                if tile < 128:
+                    raise e
+
+        upscale_model.to("cpu")
+        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
+        return (s,)
+
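+# Minimal noise generators matching the interface of the comfy_extras
+# custom-sampler classes removed from the imports above: generate_noise()
+# returns a CPU tensor shaped like the input latent.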
+class Noise_EmptyNoise:
+    def __init__(self):
+        self.seed = 0
+
+    def generate_noise(self, input_latent):
+        latent_image = input_latent["samples"]
+        return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+
+class Noise_RandomNoise:
+    def __init__(self, seed):
+        self.seed = seed
+
+    def generate_noise(self, input_latent):
+        latent_image = input_latent["samples"]
+        batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
+        return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
+# endregion
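+# Illustrative call shapes (values mirror the call sites further down):
+#   CLIPTextEncodeSDXL_encode(clip, 1024, 1024, 0, 0, 1024, 1024, text_g, text_l)[0]
+#   CLIPTextEncodeSDXLRefiner_encode(refiner_clip, 6, 4096, 4096, text)[0]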

 def calculate_file_hash(file_path):
     # open the file in binary mode
@@ -1816,6 +1909,8 @@ def INPUT_TYPES(s):
     FUNCTION = 'start'
     CATEGORY = 'Mikey'

+
+
     def start(self, clip_base, clip_refiner, positive_prompt, negative_prompt, style, ratio_selected, batch_size, seed):
         """ get output from PromptWithStyle.start """
         (latent,
@@ -1835,10 +1930,10 @@ def start(self, clip_base, clip_refiner, positive_prompt, negative_prompt, style
         # 'Target Width:', target_width, 'Target Height:', target_height,
         # 'Refiner Width:', refiner_width, 'Refiner Height:', refiner_height)
         # encode text
-        sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base, width, height, 0, 0, target_width, target_height, pos_prompt, pos_style)[0]
-        sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base, width, height, 0, 0, target_width, target_height, neg_prompt, neg_style)[0]
-        refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt)[0]
-        refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt)[0]
+        sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base, width, height, 0, 0, target_width, target_height, pos_prompt, pos_style)[0]
+        sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base, width, height, 0, 0, target_width, target_height, neg_prompt, neg_style)[0]
+        refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt)[0]
+        refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt)[0]
         # return
         return (latent,
                 sdxl_pos_cond, sdxl_neg_cond,
@@ -2117,10 +2212,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
         # encode text
         add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt, clip_l_positive=pos_style_)
         add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt, clip_l_negative=neg_style_)
-        sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
-        sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
-        refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
-        refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
+        sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
+        sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
+        refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
+        refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
         #prompt.get(str(unique_id))['inputs']['output_positive_prompt'] = pos_prompt_
         #prompt.get(str(unique_id))['inputs']['output_negative_prompt'] = neg_prompt_
         #prompt.get(str(unique_id))['inputs']['output_latent_width'] = width
@@ -2166,10 +2261,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
             # encode text
             add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt_, clip_l_positive=pos_style_)
             add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt_, clip_l_negative=neg_style_)
-            base_pos_conds.append(CLIPTextEncodeSDXL.encode(self, clip_base_pos, width_, height_, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0])
-            base_neg_conds.append(CLIPTextEncodeSDXL.encode(self, clip_base_neg, width_, height_, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0])
-            refiner_pos_conds.append(CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width_, refiner_height_, pos_prompt_)[0])
-            refiner_neg_conds.append(CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width_, refiner_height_, neg_prompt_)[0])
+            base_pos_conds.append(CLIPTextEncodeSDXL_encode(clip_base_pos, width_, height_, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0])
+            base_neg_conds.append(CLIPTextEncodeSDXL_encode(clip_base_neg, width_, height_, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0])
+            refiner_pos_conds.append(CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width_, refiner_height_, pos_prompt_)[0])
+            refiner_neg_conds.append(CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width_, refiner_height_, neg_prompt_)[0])
         # if none of the styles matched we will get an empty list so we need to check for that again
         if len(base_pos_conds) == 0:
             style_ = 'none'
@@ -2180,10 +2275,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
             # encode text
             add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt_, clip_l_positive=pos_style_)
             add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt_, clip_l_negative=neg_style_)
-            sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
-            sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
-            refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
-            refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
+            sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
+            sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
+            refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
+            refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
             #prompt.get(str(unique_id))['inputs']['output_positive_prompt'] = pos_prompt_
             #prompt.get(str(unique_id))['inputs']['output_negative_prompt'] = neg_prompt_
             #prompt.get(str(unique_id))['inputs']['output_latent_width'] = width
@@ -2388,10 +2483,10 @@ def add_style(self, style, strength, positive_cond_base, negative_cond_base,
         if style == 'none':
             return (positive_cond_base, negative_cond_base, positive_cond_refiner, negative_cond_refiner, style,)
         # encode the style prompt
-        positive_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
-        negative_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
-        positive_cond_refiner_new = CLIPTextEncodeSDXLRefiner.encode(self, refiner_clip, 6, 4096, 4096, pos_prompt)[0]
-        negative_cond_refiner_new = CLIPTextEncodeSDXLRefiner.encode(self, refiner_clip, 2.5, 4096, 4096, neg_prompt)[0]
+        positive_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
+        negative_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
+        positive_cond_refiner_new = CLIPTextEncodeSDXLRefiner_encode(refiner_clip, 6, 4096, 4096, pos_prompt)[0]
+        negative_cond_refiner_new = CLIPTextEncodeSDXLRefiner_encode(refiner_clip, 2.5, 4096, 4096, neg_prompt)[0]
         # average the style prompt with the existing conditioning
         positive_cond_base = ConditioningAverage.addWeighted(self, positive_cond_base_new, positive_cond_base, strength)[0]
         negative_cond_base = ConditioningAverage.addWeighted(self, negative_cond_base_new, negative_cond_base, strength)[0]
@@ -2430,8 +2525,8 @@ def add_style(self, style, strength, positive_cond_base, negative_cond_base,
         if style == 'none':
             return (positive_cond_base, negative_cond_base, style,)
         # encode the style prompt
-        positive_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
-        negative_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
+        positive_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
+        negative_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
         # average the style prompt with the existing conditioning
         positive_cond_base = ConditioningAverage.addWeighted(self, positive_cond_base_new, positive_cond_base, strength)[0]
         negative_cond_base = ConditioningAverage.addWeighted(self, negative_cond_base_new, negative_cond_base, strength)[0]