@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
3737 nerf_final_head_type : str
3838 # None means use the same dtype as the model.
3939 nerf_embedder_dtype : Optional [torch .dtype ]
40-
40+ use_x0 : bool
4141
4242class ChromaRadiance (Chroma ):
4343 """
@@ -159,6 +159,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
159159 self .skip_dit = []
160160 self .lite = False
161161
162+ if params .use_x0 :
163+ self .register_buffer ("__x0__" , torch .tensor ([]))
164+
162165 @property
163166 def _nerf_final_layer (self ) -> nn .Module :
164167 if self .params .nerf_final_head_type == "linear" :
@@ -276,6 +279,12 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
276279 params_dict |= overrides
277280 return params .__class__ (** params_dict )
278281
282+ def _apply_x0_residual (self , predicted , noisy , timesteps ):
283+
284+ # non zero during training to prevent 0 div
285+ eps = 0.0
286+ return (noisy - predicted ) / (timesteps .view (- 1 ,1 ,1 ,1 ) + eps )
287+
279288 def _forward (
280289 self ,
281290 x : Tensor ,
@@ -316,4 +325,11 @@ def _forward(
316325 transformer_options ,
317326 attn_mask = kwargs .get ("attention_mask" , None ),
318327 )
319- return self .forward_nerf (img , img_out , params )[:, :, :h , :w ]
328+
329+ out = self .forward_nerf (img , img_out , params )[:, :, :h , :w ]
330+
331+ # If x0 variant → v-pred, just return this instead
332+ if hasattr (self , "__x0__" ):
333+ out = self ._apply_x0_residual (out , img , timestep )
334+ return out
335+
0 commit comments