1919import scipy .stats
2020import numpy
2121
22+
23+ def add_area_dims (area , num_dims ):
24+ while (len (area ) // 2 ) < num_dims :
25+ area = [2147483648 ] + area [:len (area ) // 2 ] + [0 ] + area [len (area ) // 2 :]
26+ return area
27+
2228def get_area_and_mult (conds , x_in , timestep_in ):
2329 dims = tuple (x_in .shape [2 :])
2430 area = None
@@ -34,8 +40,9 @@ def get_area_and_mult(conds, x_in, timestep_in):
3440 return None
3541 if 'area' in conds :
3642 area = list (conds ['area' ])
37- while (len (area ) // 2 ) < len (dims ):
38- area = [2147483648 ] + area [:len (area ) // 2 ] + [0 ] + area [len (area ) // 2 :]
43+ area = add_area_dims (area , len (dims ))
44+ if (len (area ) // 2 ) > len (dims ):
45+ area = area [:len (dims )] + area [len (area ) // 2 :(len (area ) // 2 ) + len (dims )]
3946
4047 if 'strength' in conds :
4148 strength = conds ['strength' ]
@@ -53,7 +60,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
5360 if "mask_strength" in conds :
5461 mask_strength = conds ["mask_strength" ]
5562 mask = conds ['mask' ]
56- assert (mask .shape [1 :] == x_in .shape [2 :])
63+ assert (mask .shape [1 :] == x_in .shape [2 :])
5764
5865 mask = mask [:input_x .shape [0 ]]
5966 if area is not None :
@@ -67,16 +74,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
6774 mult = mask * strength
6875
6976 if 'mask' not in conds and area is not None :
70- rr = 8
77+ fuzz = 8
7178 for i in range (len (dims )):
79+ rr = min (fuzz , mult .shape [2 + i ] // 4 )
7280 if area [len (dims ) + i ] != 0 :
7381 for t in range (rr ):
7482 m = mult .narrow (i + 2 , t , 1 )
75- m *= ((1.0 / rr ) * (t + 1 ))
83+ m *= ((1.0 / rr ) * (t + 1 ))
7684 if (area [i ] + area [len (dims ) + i ]) < x_in .shape [i + 2 ]:
7785 for t in range (rr ):
7886 m = mult .narrow (i + 2 , area [i ] - 1 - t , 1 )
79- m *= ((1.0 / rr ) * (t + 1 ))
87+ m *= ((1.0 / rr ) * (t + 1 ))
8088
8189 conditioning = {}
8290 model_conds = conds ["model_conds" ]
@@ -551,25 +559,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
551559 logging .warning ("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead." )
552560 return resolve_areas_and_cond_masks_multidim (conditions , [h , w ], device )
553561
554- def create_cond_with_same_area_if_none (conds , c ): #TODO: handle dim != 2
562+ def create_cond_with_same_area_if_none (conds , c ):
555563 if 'area' not in c :
556564 return
557565
566+ def area_inside (a , area_cmp ):
567+ a = add_area_dims (a , len (area_cmp ) // 2 )
568+ area_cmp = add_area_dims (area_cmp , len (a ) // 2 )
569+
570+ a_l = len (a ) // 2
571+ area_cmp_l = len (area_cmp ) // 2
572+ for i in range (min (a_l , area_cmp_l )):
573+ if a [a_l + i ] < area_cmp [area_cmp_l + i ]:
574+ return False
575+ for i in range (min (a_l , area_cmp_l )):
576+ if (a [i ] + a [a_l + i ]) > (area_cmp [i ] + area_cmp [area_cmp_l + i ]):
577+ return False
578+ return True
579+
558580 c_area = c ['area' ]
559581 smallest = None
560582 for x in conds :
561583 if 'area' in x :
562584 a = x ['area' ]
563- if c_area [2 ] >= a [2 ] and c_area [3 ] >= a [3 ]:
564- if a [0 ] + a [2 ] >= c_area [0 ] + c_area [2 ]:
565- if a [1 ] + a [3 ] >= c_area [1 ] + c_area [3 ]:
566- if smallest is None :
567- smallest = x
568- elif 'area' not in smallest :
569- smallest = x
570- else :
571- if smallest ['area' ][0 ] * smallest ['area' ][1 ] > a [0 ] * a [1 ]:
572- smallest = x
585+ if area_inside (c_area , a ):
586+ if smallest is None :
587+ smallest = x
588+ elif 'area' not in smallest :
589+ smallest = x
590+ else :
591+ if math .prod (smallest ['area' ][:len (smallest ['area' ]) // 2 ]) > math .prod (a [:len (a ) // 2 ]):
592+ smallest = x
573593 else :
574594 if smallest is None :
575595 smallest = x
0 commit comments