
Commit bae257d

Merge branch 'main' into hunyuan-video
2 parents 59c8552 + cef0e36 commit bae257d

File tree

2 files changed: +14 -15 lines changed


examples/community/pipeline_flux_rf_inversion.py

Lines changed: 5 additions & 6 deletions
@@ -648,6 +648,8 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         eta: float = 1.0,
+        decay_eta: Optional[bool] = False,
+        eta_decay_power: Optional[float] = 1.0,
         strength: float = 1.0,
         start_timestep: float = 0,
         stop_timestep: float = 0.25,
@@ -880,12 +882,9 @@ def __call__(
                 v_t = -noise_pred
                 v_t_cond = (y_0 - latents) / (1 - t_i)
                 eta_t = eta if start_timestep <= i < stop_timestep else 0.0
-                if start_timestep <= i < stop_timestep:
-                    # controlled vector field
-                    v_hat_t = v_t + eta * (v_t_cond - v_t)
-
-                else:
-                    v_hat_t = v_t
+                if decay_eta:
+                    eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power  # Decay eta over the loop
+                v_hat_t = v_t + eta_t * (v_t_cond - v_t)

                 # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
                 latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
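
The hunk above replaces the hard on/off branch on the controlled vector field with a single interpolation whose strength eta_t can optionally decay over the denoising loop. As a rough, self-contained sketch of that update rule (not the pipeline's actual code path): the function name controlled_velocity is hypothetical, and start_step / stop_step stand in for the timestep window already converted to step indices, which is an assumption since that conversion is not shown in this diff.

import torch

def controlled_velocity(
    v_t: torch.Tensor,        # unconditional velocity, i.e. -noise_pred
    v_t_cond: torch.Tensor,   # conditional velocity pulling toward y_0
    i: int,                   # current denoising step index
    num_inference_steps: int,
    start_step: int,          # assumed: start of the guidance window, as a step index
    stop_step: int,           # assumed: end of the guidance window, as a step index
    eta: float = 1.0,
    decay_eta: bool = False,
    eta_decay_power: float = 1.0,
) -> torch.Tensor:
    # eta is only active inside the [start_step, stop_step) window
    eta_t = eta if start_step <= i < stop_step else 0.0
    if decay_eta:
        # optional polynomial decay of eta over the course of the loop
        eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power
    # with eta_t == 0 this reduces to v_t, matching the old else-branch
    return v_t + eta_t * (v_t_cond - v_t)

With decay_eta=False this reproduces the previous behavior (flat eta inside the window, pure v_t outside); with decay_eta=True the conditional guidance fades toward zero as the loop progresses.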

src/diffusers/models/activations.py

Lines changed: 9 additions & 9 deletions
@@ -18,7 +18,7 @@
 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
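
Both gelu helpers now take the float32 round-trip only on mps with a torch version older than 2.0, where fp16 gelu is not implemented, and otherwise call F.gelu directly in the input dtype. A standalone sketch of the same pattern, using packaging.version as a simplified stand-in for diffusers' is_torch_version helper and a hypothetical gelu_with_mps_fallback function:

import torch
import torch.nn.functional as F
from packaging import version

def _torch_older_than(v: str) -> bool:
    # simplified stand-in for diffusers.utils.import_utils.is_torch_version("<", v)
    return version.parse(torch.__version__) < version.parse(v)

def gelu_with_mps_fallback(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    if gate.device.type == "mps" and _torch_older_than("2.0.0"):
        # fp16 gelu is not implemented on mps before torch 2.0:
        # upcast to float32, apply gelu, then cast back to the input dtype
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    # cpu, cuda, or mps on torch >= 2.0: run gelu directly in the input dtype
    return F.gelu(gate, approximate=approximate)

The net effect of the change is that the upcast penalty is paid only on the one configuration that needs it, instead of on every mps run.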
