@@ -2,18 +2,18 @@
 
 import math
 from einops import rearrange
-import random
+# Use torch rng for consistency across generations
+from torch import randint
 
-def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter=0) -> int:
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
     min_value = min(min_value, value)
 
     # All big divisors of value (inclusive)
     divisors = [i for i in range(min_value, value + 1) if value % i == 0]
 
     ns = [value // i for i in divisors[:max_options]]  # has at least 1 element
 
-    random.seed(counter)
-    idx = random.randint(0, len(ns) - 1)
+    idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
 
     return ns[idx]
 
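One editorial note on the new selection line, not part of the commit: `torch.randint` treats `high` as exclusive, whereas the old `random.randint(0, len(ns) - 1)` was inclusive on both ends. As written, the last entry of `ns` is never drawn, and when `ns` has exactly one element (the default `max_options=1`) the call raises because `low >= high`. A minimal guarded sketch that preserves the original inclusive behavior:

```python
from torch import randint

def pick_index(n_options: int) -> int:
    # Illustrative helper, not in the commit: torch.randint's high bound is
    # exclusive, so high=n_options keeps every option reachable and avoids
    # the "from >= to" RuntimeError when n_options == 1.
    return randint(low=0, high=n_options, size=(1,)).item()
```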
@@ -42,7 +42,6 @@ def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
 
         latent_tile_size = max(32, tile_size) // 8
         self.temp = None
-        self.counter = 1
 
         def hypertile_in(q, k, v, extra_options):
             if q.shape[-1] in apply_to:
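For context (an editor's gloss, not part of the diff): `latent_tile_size` converts the pixel-space tile size into latent units, since Stable Diffusion's VAE downsamples by a factor of 8, with a 32 px floor:

```python
# assuming a tile_size of 256 px from the node's inputs
tile_size = 256
latent_tile_size = max(32, tile_size) // 8  # 256 px -> 32 latent cells
assert latent_tile_size == 32
```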
@@ -53,10 +52,8 @@ def hypertile_in(q, k, v, extra_options):
                 h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
 
                 factor = 2 ** ((q.shape[-1] // model_channels) - 1) if scale_depth else 1
-                nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
-                nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
+                nh = random_divisor(h, latent_tile_size * factor, swap_size)
+                nw = random_divisor(w, latent_tile_size * factor, swap_size)
 
                 if nh * nw > 1:
                     q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
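The payoff of the switch (my reading of the commit comment above): the old code reseeded Python's global `random` with a per-call counter, so divisor choices depended on hidden mutable state carried across calls; drawing from torch's generator instead ties them to whatever seed the sampler sets, making tile splits repeatable across generations. A small self-contained illustration, using a hypothetical `sample_option` helper:

```python
import torch

def sample_option(n_options: int) -> int:
    # Mirrors the new randint-based selection (high is exclusive).
    return torch.randint(low=0, high=n_options, size=(1,)).item()

torch.manual_seed(42)                          # seeded once per generation
first = [sample_option(4) for _ in range(3)]
torch.manual_seed(42)                          # same seed next generation...
second = [sample_option(4) for _ in range(3)]
assert first == second                         # ...same tile splits
```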