Skip to content

Commit b9fb542

Browse files
add chroma-radiance-x0 mode (#11197)
1 parent cabc4d3 commit b9fb542

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

comfy/ldm/chroma_radiance/model.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
3737
nerf_final_head_type: str
3838
# None means use the same dtype as the model.
3939
nerf_embedder_dtype: Optional[torch.dtype]
40-
40+
use_x0: bool
4141

4242
class ChromaRadiance(Chroma):
4343
"""
@@ -159,6 +159,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
159159
self.skip_dit = []
160160
self.lite = False
161161

162+
if params.use_x0:
163+
self.register_buffer("__x0__", torch.tensor([]))
164+
162165
@property
163166
def _nerf_final_layer(self) -> nn.Module:
164167
if self.params.nerf_final_head_type == "linear":
@@ -276,6 +279,12 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
276279
params_dict |= overrides
277280
return params.__class__(**params_dict)
278281

282+
def _apply_x0_residual(self, predicted, noisy, timesteps):
283+
284+
# non zero during training to prevent 0 div
285+
eps = 0.0
286+
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
287+
279288
def _forward(
280289
self,
281290
x: Tensor,
@@ -316,4 +325,11 @@ def _forward(
316325
transformer_options,
317326
attn_mask=kwargs.get("attention_mask", None),
318327
)
319-
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
328+
329+
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
330+
331+
# x0 variant: convert the x0 prediction to a v-prediction and return that instead
332+
if hasattr(self, "__x0__"):
333+
out = self._apply_x0_residual(out, img, timestep)
334+
return out
335+

comfy/model_detection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
257257
dit_config["nerf_tile_size"] = 512
258258
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
259259
dit_config["nerf_embedder_dtype"] = torch.float32
260+
if "__x0__" in state_dict_keys: # x0 pred
261+
dit_config["use_x0"] = True
260262
else:
261263
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
262264
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

0 commit comments

Comments
 (0)