1111import torch .nn .functional as F
1212from torch .amp import autocast
1313
14- from einops import rearrange , reduce , repeat
14+ from einops import rearrange , reduce , repeat , pack , unpack
1515from einops .layers .torch import Rearrange
1616
1717from tqdm .auto import tqdm
@@ -54,6 +54,15 @@ def convert_image_to_fn(img_type, image):
5454 return image .convert (img_type )
5555 return image
5656
def pack_one_with_inverse(x, pattern):
    """Pack a single tensor with einops and hand back a matching un-packer.

    Args:
        x: the tensor to pack.
        pattern: einops pack pattern applied to `x` (e.g. `'b *'`).

    Returns:
        A tuple `(packed, inverse)`. `packed` is the result of packing `[x]`
        with `pattern`. `inverse(x, inverse_pattern = None)` unpacks a tensor
        using the shape information recorded during packing; it uses
        `inverse_pattern` when given and falls back to `pattern` otherwise.
    """
    packed, shape_info = pack([x], pattern)

    def unpack_one(x, inverse_pattern = None):
        # only one tensor was packed, so only one comes back out
        chosen_pattern = default(inverse_pattern, pattern)
        unpacked, = unpack(x, shape_info, chosen_pattern)
        return unpacked

    return packed, unpack_one
65+
5766# normalization functions
5867
5968def normalize_to_neg_one_to_one (img ):
@@ -75,6 +84,19 @@ def prob_mask_like(shape, prob, device):
7584 else :
7685 return torch .zeros (shape , device = device ).float ().uniform_ (0 , 1 ) < prob
7786
def project(x, y):
    """Decompose `x` into parts parallel and orthogonal to `y`, per batch element.

    Both inputs are packed with the `'b *'` pattern, so every axis after the
    first is treated as a single vector for the purpose of the projection.

    Args:
        x: tensor to decompose.
        y: tensor giving the direction to project onto; packed with the same
            pattern as `x`.

    Returns:
        A tuple `(parallel, orthogonal)`. Both are unpacked with the shape
        information recorded from `x` and cast back to the original dtype of
        `x`. `orthogonal` is computed as `x - parallel`.
    """
    x, inverse = pack_one_with_inverse(x, 'b *')
    y, _ = pack_one_with_inverse(y, 'b *')

    dtype = x.dtype

    # The projection is done in float64 for accuracy. The MPS backend does not
    # support float64 and raises on conversion, so use float32 there instead.
    compute_dtype = torch.float32 if x.device.type == 'mps' else torch.float64
    x, y = x.to(compute_dtype), y.to(compute_dtype)

    # F.normalize divides by max(norm, eps), so an all-zero `y` gives a zero
    # unit vector (and hence a zero parallel component) rather than NaNs.
    unit = F.normalize(y, dim = -1)

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel

    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
99+
78100# small helper modules
79101
80102class Residual (nn .Module ):
@@ -357,6 +379,8 @@ def forward_with_cond_scale(
357379 * args ,
358380 cond_scale = 1. ,
359381 rescaled_phi = 0. ,
382+ remove_parallel_component = True ,
383+ keep_parallel_frac = 0. ,
360384 ** kwargs
361385 ):
362386 logits = self .forward (* args , cond_drop_prob = 0. , ** kwargs )
@@ -365,7 +389,13 @@ def forward_with_cond_scale(
365389 return logits
366390
367391 null_logits = self .forward (* args , cond_drop_prob = 1. , ** kwargs )
368- scaled_logits = null_logits + (logits - null_logits ) * cond_scale
392+ update = logits - null_logits
393+
394+ if remove_parallel_component :
395+ parallel , orthog = project (update , logits )
396+ update = orthog + parallel * keep_parallel_frac
397+
398+ scaled_logits = logits + update * (cond_scale - 1. )
369399
370400 if rescaled_phi == 0. :
371401 return scaled_logits , null_logits
0 commit comments