from tqdm import tqdm

import folder_paths
-from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner
import comfy_extras
-from comfy_extras.nodes_upscale_model import UpscaleModelLoader, ImageUpscaleWithModel
-from comfy.model_management import soft_empty_cache, free_memory, get_torch_device, current_loaded_models, load_model_gpu
+from spandrel import ModelLoader, ImageModelDescriptor
+from comfy import model_management
from nodes import LoraLoader, ConditioningAverage, common_ksampler, ImageScale, ImageScaleBy, VAEEncode, VAEDecode
import comfy.utils
-from comfy_extras.chainner_models import model_loading
-from comfy_extras.nodes_custom_sampler import Noise_EmptyNoise, Noise_RandomNoise
-from comfy import model_management, model_base
+
+# region COMFY_EXTRAS
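+# Inlined stand-ins for the comfy_extras pieces dropped from the imports above
+# (CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner, UpscaleModelLoader,
+# ImageUpscaleWithModel, Noise_EmptyNoise, Noise_RandomNoise), so this file no
+# longer depends on comfy_extras module paths.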
+def CLIPTextEncodeSDXLRefiner_encode(clip, ascore, width, height, text):
+    tokens = clip.tokenize(text)
+    return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
+
+def CLIPTextEncodeSDXL_encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
+    tokens = clip.tokenize(text_g)
+    tokens["l"] = clip.tokenize(text_l)["l"]
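+    # pad whichever of the CLIP-L / CLIP-G token lists is shorter with
+    # empty-prompt tokens so both streams end up the same length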
+    if len(tokens["l"]) != len(tokens["g"]):
+        empty = clip.tokenize("")
+        while len(tokens["l"]) < len(tokens["g"]):
+            tokens["l"] += empty["l"]
+        while len(tokens["l"]) > len(tokens["g"]):
+            tokens["g"] += empty["g"]
+    return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
+
+class UpscaleModelLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"), ),
+                             }}
+    RETURN_TYPES = ("UPSCALE_MODEL",)
+    FUNCTION = "load_model"
+
+    CATEGORY = "loaders"
+
+    def load_model(self, model_name):
+        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
+        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
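+        # some checkpoints save weights under a "module." prefix; strip it so
+        # spandrel can recognize the architecture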
62+ if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd :
63+ sd = comfy .utils .state_dict_prefix_replace (sd , {"module." :"" })
64+ out = ModelLoader ().load_from_state_dict (sd ).eval ()
65+
66+ if not isinstance (out , ImageModelDescriptor ):
67+ raise Exception ("Upscale model must be a single-image model." )
68+
69+ return (out , )
70+
+class ImageUpscaleWithModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
+                             "image": ("IMAGE",),
+                             }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+
+    CATEGORY = "image/upscaling"
+
+    def upscale(self, upscale_model, image):
+        device = model_management.get_torch_device()
+
+        memory_required = model_management.module_size(upscale_model.model)
+        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # the 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
+        memory_required += image.nelement() * image.element_size()
+        model_management.free_memory(memory_required, device)
+
+        upscale_model.to(device)
+        in_img = image.movedim(-1, -3).to(device)
+
+        tile = 512
+        overlap = 32
+
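+        # retry the tiled upscale with progressively smaller tiles if the GPU
+        # runs out of memory; give up once tiles would shrink below 128px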
+        oom = True
+        while oom:
+            try:
+                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+                pbar = comfy.utils.ProgressBar(steps)
+                s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                oom = False
+            except model_management.OOM_EXCEPTION as e:
+                tile //= 2
+                if tile < 128:
+                    raise e
+
+        upscale_model.to("cpu")
+        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
+        return (s,)
+
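+# Local copies of the Noise_EmptyNoise / Noise_RandomNoise helpers from
+# comfy_extras.nodes_custom_sampler, whose import was removed above.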
+class Noise_EmptyNoise:
+    def __init__(self):
+        self.seed = 0
+
+    def generate_noise(self, input_latent):
+        latent_image = input_latent["samples"]
+        return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+
+class Noise_RandomNoise:
+    def __init__(self, seed):
+        self.seed = seed
+
+    def generate_noise(self, input_latent):
+        latent_image = input_latent["samples"]
+        batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
+        return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
+# endregion

def calculate_file_hash(file_path):
    # open the file in binary mode
@@ -1816,6 +1909,8 @@ def INPUT_TYPES(s):
    FUNCTION = 'start'
    CATEGORY = 'Mikey'

+
+
    def start(self, clip_base, clip_refiner, positive_prompt, negative_prompt, style, ratio_selected, batch_size, seed):
        """ get output from PromptWithStyle.start """
        (latent,
@@ -1835,10 +1930,10 @@ def start(self, clip_base, clip_refiner, positive_prompt, negative_prompt, style
        # 'Target Width:', target_width, 'Target Height:', target_height,
        # 'Refiner Width:', refiner_width, 'Refiner Height:', refiner_height)
        # encode text
-        sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base, width, height, 0, 0, target_width, target_height, pos_prompt, pos_style)[0]
-        sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base, width, height, 0, 0, target_width, target_height, neg_prompt, neg_style)[0]
-        refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt)[0]
-        refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt)[0]
+        sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base, width, height, 0, 0, target_width, target_height, pos_prompt, pos_style)[0]
+        sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base, width, height, 0, 0, target_width, target_height, neg_prompt, neg_style)[0]
+        refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt)[0]
+        refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt)[0]
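+        # the refiner is conditioned on an SDXL aesthetic score: high (6) for
+        # the positive prompt, low (2.5) for the negative prompt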
        # return
        return (latent,
                sdxl_pos_cond, sdxl_neg_cond,
@@ -2117,10 +2212,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
        # encode text
        add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt, clip_l_positive=pos_style_)
        add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt, clip_l_negative=neg_style_)
-        sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
-        sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
-        refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
-        refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
+        sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
+        sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
+        refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
+        refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
        #prompt.get(str(unique_id))['inputs']['output_positive_prompt'] = pos_prompt_
        #prompt.get(str(unique_id))['inputs']['output_negative_prompt'] = neg_prompt_
        #prompt.get(str(unique_id))['inputs']['output_latent_width'] = width
@@ -2166,10 +2261,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
            # encode text
            add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt_, clip_l_positive=pos_style_)
            add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt_, clip_l_negative=neg_style_)
-            base_pos_conds.append(CLIPTextEncodeSDXL.encode(self, clip_base_pos, width_, height_, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0])
-            base_neg_conds.append(CLIPTextEncodeSDXL.encode(self, clip_base_neg, width_, height_, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0])
-            refiner_pos_conds.append(CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width_, refiner_height_, pos_prompt_)[0])
-            refiner_neg_conds.append(CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width_, refiner_height_, neg_prompt_)[0])
+            base_pos_conds.append(CLIPTextEncodeSDXL_encode(clip_base_pos, width_, height_, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0])
+            base_neg_conds.append(CLIPTextEncodeSDXL_encode(clip_base_neg, width_, height_, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0])
+            refiner_pos_conds.append(CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width_, refiner_height_, pos_prompt_)[0])
+            refiner_neg_conds.append(CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width_, refiner_height_, neg_prompt_)[0])
        # if none of the styles matched we will get an empty list so we need to check for that again
        if len(base_pos_conds) == 0:
            style_ = 'none'
@@ -2180,10 +2275,10 @@ def start(self, base_model, clip_base, clip_refiner, positive_prompt, negative_p
            # encode text
            add_metadata_to_dict(prompt_with_style, style=style_, clip_g_positive=pos_prompt_, clip_l_positive=pos_style_)
            add_metadata_to_dict(prompt_with_style, clip_g_negative=neg_prompt_, clip_l_negative=neg_style_)
-            sdxl_pos_cond = CLIPTextEncodeSDXL.encode(self, clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
-            sdxl_neg_cond = CLIPTextEncodeSDXL.encode(self, clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
-            refiner_pos_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
-            refiner_neg_cond = CLIPTextEncodeSDXLRefiner.encode(self, clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
+            sdxl_pos_cond = CLIPTextEncodeSDXL_encode(clip_base_pos, width, height, 0, 0, target_width, target_height, pos_prompt_, pos_style_)[0]
+            sdxl_neg_cond = CLIPTextEncodeSDXL_encode(clip_base_neg, width, height, 0, 0, target_width, target_height, neg_prompt_, neg_style_)[0]
+            refiner_pos_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 6, refiner_width, refiner_height, pos_prompt_)[0]
+            refiner_neg_cond = CLIPTextEncodeSDXLRefiner_encode(clip_refiner, 2.5, refiner_width, refiner_height, neg_prompt_)[0]
        #prompt.get(str(unique_id))['inputs']['output_positive_prompt'] = pos_prompt_
        #prompt.get(str(unique_id))['inputs']['output_negative_prompt'] = neg_prompt_
        #prompt.get(str(unique_id))['inputs']['output_latent_width'] = width
@@ -2388,10 +2483,10 @@ def add_style(self, style, strength, positive_cond_base, negative_cond_base,
        if style == 'none':
            return (positive_cond_base, negative_cond_base, positive_cond_refiner, negative_cond_refiner, style, )
        # encode the style prompt
-        positive_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
-        negative_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
-        positive_cond_refiner_new = CLIPTextEncodeSDXLRefiner.encode(self, refiner_clip, 6, 4096, 4096, pos_prompt)[0]
-        negative_cond_refiner_new = CLIPTextEncodeSDXLRefiner.encode(self, refiner_clip, 2.5, 4096, 4096, neg_prompt)[0]
+        positive_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
+        negative_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
+        positive_cond_refiner_new = CLIPTextEncodeSDXLRefiner_encode(refiner_clip, 6, 4096, 4096, pos_prompt)[0]
+        negative_cond_refiner_new = CLIPTextEncodeSDXLRefiner_encode(refiner_clip, 2.5, 4096, 4096, neg_prompt)[0]
        # average the style prompt with the existing conditioning
        positive_cond_base = ConditioningAverage.addWeighted(self, positive_cond_base_new, positive_cond_base, strength)[0]
        negative_cond_base = ConditioningAverage.addWeighted(self, negative_cond_base_new, negative_cond_base, strength)[0]
@@ -2430,8 +2525,8 @@ def add_style(self, style, strength, positive_cond_base, negative_cond_base,
        if style == 'none':
            return (positive_cond_base, negative_cond_base, style, )
        # encode the style prompt
-        positive_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
-        negative_cond_base_new = CLIPTextEncodeSDXL.encode(self, base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
+        positive_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, pos_prompt, pos_prompt)[0]
+        negative_cond_base_new = CLIPTextEncodeSDXL_encode(base_clip, 1024, 1024, 0, 0, 1024, 1024, neg_prompt, neg_prompt)[0]
        # average the style prompt with the existing conditioning
        positive_cond_base = ConditioningAverage.addWeighted(self, positive_cond_base_new, positive_cond_base, strength)[0]
        negative_cond_base = ConditioningAverage.addWeighted(self, negative_cond_base_new, negative_cond_base, strength)[0]