@@ -37,21 +37,25 @@
     build_laplacian_pyramid_func = None


-def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from
     paper (Algorithm 2).
     """
     # v0 shape: [B, ...]
     # v1 shape: [B, ...]
-    dtype = v0.dtype
     # Assume first dim is a batch dim and all other dims are channel or "spatial" dims
     all_dims_but_first = list(range(1, len(v0.shape)))
-    v0, v1 = v0.double(), v1.double()
+    if upcast_to_double:
+        dtype = v0.dtype
+        v0, v1 = v0.double(), v1.double()
     v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
     v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
     v0_orthogonal = v0 - v0_parallel
-    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
+    if upcast_to_double:
+        v0_parallel = v0_parallel.to(dtype)
+        v0_orthogonal = v0_orthogonal.to(dtype)
+    return v0_parallel, v0_orthogonal


 def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
@@ -127,6 +131,9 @@ class FrequencyDecoupledGuidance(BaseGuidance):
             `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
             speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
             will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
+        upcast_to_double (`bool`, defaults to `True`):
+            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
+            float64 when performing guidance. This may produce better results at the cost of increased runtime.
     """

     _input_predictions = ["pred_cond", "pred_uncond"]
@@ -141,6 +148,7 @@ def __init__(
         start: Union[float, List[float], Tuple[float]] = 0.0,
         stop: Union[float, List[float], Tuple[float]] = 1.0,
         guidance_rescale_space: str = "data",
+        upcast_to_double: bool = True,
     ):
         if not _CAN_USE_KORNIA:
             raise ImportError(
@@ -188,6 +196,7 @@ def __init__(
             )

         self.use_original_formulation = use_original_formulation
+        self.upcast_to_double = upcast_to_double

         if isinstance(start, float):
             self.guidance_start = [start] * self.levels
@@ -244,7 +253,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =

                     # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
                     if not math.isclose(parallel_weight, 1.0):
-                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq)
+                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                         shift = parallel_weight * shift_parallel + shift_orthogonal

                     # Apply CFG update for the current frequency level
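
For context, here is a minimal, self-contained sketch of how the new `upcast_to_double` flag flows through the projection used for `parallel_weights`. The helper simply mirrors the diff above; the tensor shapes, dtype, and `parallel_weight` value are illustrative assumptions, not part of this change.

```python
from typing import Tuple

import torch


def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the helper in the diff above: split v0 into components parallel
    # and orthogonal to v1, optionally doing the arithmetic in float64.
    all_dims_but_first = list(range(1, len(v0.shape)))
    if upcast_to_double:
        dtype = v0.dtype
        v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
    v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    if upcast_to_double:
        v0_parallel, v0_orthogonal = v0_parallel.to(dtype), v0_orthogonal.to(dtype)
    return v0_parallel, v0_orthogonal


# Assumed example inputs: a per-frequency-level CFG shift and conditional prediction in half precision.
shift = torch.randn(2, 4, 64, 64, dtype=torch.float16)
pred_cond_freq = torch.randn(2, 4, 64, 64, dtype=torch.float16)
parallel_weight = 0.5

# With upcast_to_double=True the projection runs in float64 and the results are
# cast back to the input dtype, so downstream code still sees float16 tensors.
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, upcast_to_double=True)
shift = parallel_weight * shift_parallel + shift_orthogonal
assert shift.dtype == torch.float16
```

Keeping the upcast optional lets callers trade the extra float64 compute against the numerical precision of the projection, with the previous behavior preserved by the `True` default.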