2020
2121from . import mat
2222from .util import (
23+ BlurKernel ,
24+ mask_blur ,
2325 gaussian_blur ,
2426 binary_erosion ,
2527 binary_dilation ,
@@ -126,6 +128,7 @@ def INPUT_TYPES(s):
126128
127129 def load (self , head : str , patch : str ):
128130 head_file = folder_paths .get_full_path ("inpaint" , head )
131+ assert head_file is not None , f"Inpaint head file not found in inpaint folder: { head } "
129132 inpaint_head_model = InpaintHead ()
130133 sd = torch .load (head_file , map_location = "cpu" , weights_only = True )
131134 inpaint_head_model .load_state_dict (sd )
@@ -486,19 +489,20 @@ def INPUT_TYPES(cls):
486489 "mask" : ("MASK" ,),
487490 "grow" : ("INT" , {"default" : 16 , "min" : 0 , "max" : 8096 , "step" : 1 }),
488491 "blur" : ("INT" , {"default" : 7 , "min" : 0 , "max" : 8096 , "step" : 1 }),
492+ "blur_type" : (["box" , "linear" , "gaussian" ], {"default" : "gaussian" }),
489493 }
490494 }
491495
492496 RETURN_TYPES = ("MASK" ,)
493497 CATEGORY = "inpaint"
494498 FUNCTION = "expand"
495499
496- def expand (self , mask : Tensor , grow : int , blur : int ):
500+ def expand (self , mask : Tensor , grow : int , blur : int , blur_type : str ):
497501 mask = mask_unsqueeze (mask )
498502 if grow > 0 :
499503 mask = binary_dilation (mask , grow )
500504 if blur > 0 :
501- mask = gaussian_blur (mask , make_odd (blur ))
505+ mask = mask_blur (mask , make_odd (blur ), BlurKernel [ blur_type ] )
502506 return (mask .squeeze (1 ),)
503507
504508
@@ -510,17 +514,18 @@ def INPUT_TYPES(cls):
510514 "mask" : ("MASK" ,),
511515 "shrink" : ("INT" , {"default" : 1 , "min" : 0 , "max" : 8096 , "step" : 1 }),
512516 "blur" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 8096 , "step" : 1 }),
517+ "blur_type" : (["box" , "linear" , "gaussian" ], {"default" : "gaussian" }),
513518 }
514519 }
515520
516521 RETURN_TYPES = ("MASK" ,)
517522 CATEGORY = "inpaint"
518523 FUNCTION = "shrink"
519524
520- def shrink (self , mask : Tensor , shrink : int , blur : int ):
525+ def shrink (self , mask : Tensor , shrink : int , blur : int , blur_type : str ):
521526 mask = mask_unsqueeze (mask )
522527 if shrink > 0 :
523528 mask = binary_erosion (mask , shrink )
524529 if blur > 0 :
525- mask = gaussian_blur (mask , make_odd (blur ))
530+ mask = mask_blur (mask , make_odd (blur ), BlurKernel [ blur_type ] )
526531 return (mask .squeeze (1 ),)
0 commit comments