@@ -331,11 +331,11 @@ def INPUT_TYPES(s):
331331 "model" : "Select the BiRefNet model variant to use." ,
332332 "mask_blur" : "Specify the amount of blur to apply to the mask edges (0 for no blur, higher values for more blur)." ,
333333 "mask_offset" : "Adjust the mask boundary (positive values expand the mask, negative values shrink it)." ,
334- "background" : "Choose the background color for the final output (Alpha for transparent background)." ,
335334 "invert_output" : "Enable to invert both the image and mask output (useful for certain effects)." ,
336- "refine_foreground" : "Use Fast Foreground Colour Estimation to optimize transparent background"
335+ "refine_foreground" : "Use Fast Foreground Colour Estimation to optimize transparent background" ,
336+ "background" : "Choose background type: Alpha (transparent) or Color (custom background color)." ,
337+ "background_color" : "Choose background color (Alpha = transparent)"
337338 }
338-
339339 return {
340340 "required" : {
341341 "image" : ("IMAGE" , {"tooltip" : tooltips ["image" ]}),
@@ -344,9 +344,10 @@ def INPUT_TYPES(s):
344344 "optional" : {
345345 "mask_blur" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 64 , "step" : 1 , "tooltip" : tooltips ["mask_blur" ]}),
346346 "mask_offset" : ("INT" , {"default" : 0 , "min" : - 20 , "max" : 20 , "step" : 1 , "tooltip" : tooltips ["mask_offset" ]}),
347- "background" : (["Alpha" , "black" , "white" , "gray" , "green" , "blue" , "red" ], {"default" : "Alpha" , "tooltip" : tooltips ["background" ]}),
348347 "invert_output" : ("BOOLEAN" , {"default" : False , "tooltip" : tooltips ["invert_output" ]}),
349- "refine_foreground" : ("BOOLEAN" , {"default" : False , "tooltip" : tooltips ["refine_foreground" ]})
348+ "refine_foreground" : ("BOOLEAN" , {"default" : False , "tooltip" : tooltips ["refine_foreground" ]}),
349+ "background" : (["Alpha" , "Color" ], {"default" : "Alpha" , "tooltip" : tooltips ["background" ]}),
350+ "background_color" : ("COLOR" , {"default" : "#222222" , "tooltip" : tooltips ["background_color" ]}),
350351 }
351352 }
352353
@@ -358,35 +359,16 @@ def INPUT_TYPES(s):
358359 def process_image (self , image , model , ** params ):
359360 try :
360361 model_config = MODEL_CONFIG [model ]
361-
362- # Always use model's default resolution
363362 process_res = model_config .get ("default_res" , 1024 )
364-
365- # Handle special resolution requirements
366363 if model_config .get ("force_res" , False ):
367364 base_res = 512
368365 process_res = ((process_res + base_res - 1 ) // base_res ) * base_res
369366 else :
370367 process_res = process_res // 32 * 32
371-
372368 print (f"Using { model } model with { process_res } resolution" )
373-
374369 params ["process_res" ] = process_res
375-
376370 processed_images = []
377371 processed_masks = []
378-
379- bg_colors = {
380- "Alpha" : None ,
381- "black" : (0 , 0 , 0 ),
382- "white" : (255 , 255 , 255 ),
383- "gray" : (128 , 128 , 128 ),
384- "green" : (0 , 255 , 0 ),
385- "blue" : (0 , 0 , 255 ),
386- "red" : (255 , 0 , 0 )
387- }
388-
389- # Check and download model if needed
390372 cache_status , message = self .model .check_model_cache (model )
391373 if not cache_status :
392374 print (f"Cache check: { message } " )
@@ -395,38 +377,24 @@ def process_image(self, image, model, **params):
395377 if not download_status :
396378 handle_model_error (download_message )
397379 print ("Model files downloaded successfully" )
398-
399- # Load model if needed
400380 self .model .load_model (model )
401-
402381 for img in image :
403- # Get mask from model
404382 mask = self .model .process_image (img , params )
405-
406- # Post-process mask
407383 if params ["mask_blur" ] > 0 :
408384 mask = mask .filter (ImageFilter .GaussianBlur (radius = params ["mask_blur" ]))
409-
410385 if params ["mask_offset" ] != 0 :
411386 if params ["mask_offset" ] > 0 :
412387 for _ in range (params ["mask_offset" ]):
413388 mask = mask .filter (ImageFilter .MaxFilter (3 ))
414389 else :
415390 for _ in range (- params ["mask_offset" ]):
416391 mask = mask .filter (ImageFilter .MinFilter (3 ))
417-
418392 if params ["invert_output" ]:
419393 mask = Image .fromarray (255 - np .array (mask ))
420-
421- # Convert to tensors for refine_foreground
422394 img_tensor = torch .from_numpy (np .array (tensor2pil (img ))).permute (2 , 0 , 1 ).unsqueeze (0 ) / 255.0
423395 mask_tensor = torch .from_numpy (np .array (mask )).unsqueeze (0 ).unsqueeze (0 ) / 255.0
424-
425396 if params .get ("refine_foreground" , False ):
426- refined_fg = refine_foreground (
427- img_tensor ,
428- mask_tensor
429- )
397+ refined_fg = refine_foreground (img_tensor , mask_tensor )
430398 refined_fg = tensor2pil (refined_fg [0 ].permute (1 , 2 , 0 ))
431399 orig_image = tensor2pil (img )
432400 r , g , b = refined_fg .split ()
@@ -436,28 +404,30 @@ def process_image(self, image, model, **params):
436404 orig_rgba = orig_image .convert ("RGBA" )
437405 r , g , b , _ = orig_rgba .split ()
438406 foreground = Image .merge ('RGBA' , (r , g , b , mask ))
439-
440- if params ["background" ] != "Alpha" :
441- bg_color = bg_colors [params ["background" ]]
442- bg_image = Image .new ('RGBA' , orig_image .size , (* bg_color , 255 ))
407+ if params ["background" ] == "Alpha" :
408+ processed_images .append (pil2tensor (foreground ))
409+ else :
410+ def hex_to_rgba (hex_color ):
411+ hex_color = hex_color .lstrip ('#' )
412+ if len (hex_color ) == 6 :
413+ r , g , b = int (hex_color [0 :2 ], 16 ), int (hex_color [2 :4 ], 16 ), int (hex_color [4 :6 ], 16 )
414+ a = 255
415+ elif len (hex_color ) == 8 :
416+ r , g , b , a = int (hex_color [0 :2 ], 16 ), int (hex_color [2 :4 ], 16 ), int (hex_color [4 :6 ], 16 ), int (hex_color [6 :8 ], 16 )
417+ else :
418+ raise ValueError ("Invalid color format" )
419+ return (r , g , b , a )
420+ rgba = hex_to_rgba (params ["background_color" ])
421+ bg_image = Image .new ('RGBA' , orig_image .size , rgba )
443422 composite_image = Image .alpha_composite (bg_image , foreground )
444423 processed_images .append (pil2tensor (composite_image .convert ("RGB" )))
445- else :
446- processed_images .append (pil2tensor (foreground ))
447-
448424 processed_masks .append (pil2tensor (mask ))
449-
450- # Create mask image for visualization
451425 mask_images = []
452426 for mask_tensor in processed_masks :
453- # Convert mask to RGB image format for visualization
454427 mask_image = mask_tensor .reshape ((- 1 , 1 , mask_tensor .shape [- 2 ], mask_tensor .shape [- 1 ])).movedim (1 , - 1 ).expand (- 1 , - 1 , - 1 , 3 )
455428 mask_images .append (mask_image )
456-
457429 mask_image_output = torch .cat (mask_images , dim = 0 )
458-
459430 return (torch .cat (processed_images , dim = 0 ), torch .cat (processed_masks , dim = 0 ), mask_image_output )
460-
461431 except Exception as e :
462432 handle_model_error (f"Error in image processing: { str (e )} " )
463433
0 commit comments