Skip to content

Commit acbc6a5

Browse files
committed
up
1 parent 81440fd commit acbc6a5

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,22 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
241241

242242
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
243243
freqs_cis = []
244-
# Use float32 for MPS compatibility
245-
dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
246244
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
247-
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
245+
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
248246
freqs_cis.append(emb)
249247
return freqs_cis
250248

251249
def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
250+
device = ids.device
251+
if ids.device.type == "mps":
252+
ids = ids.to("cpu")
253+
252254
result = []
253255
for i in range(len(self.axes_dim)):
254256
freqs = self.freqs_cis[i].to(ids.device)
255257
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
256258
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
257-
return torch.cat(result, dim=-1)
259+
return torch.cat(result, dim=-1).to(device)
258260

259261
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
260262
batch_size = len(hidden_states)

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
2525
from ...schedulers import FlowMatchEulerDiscreteScheduler
2626
from ...utils import (
27-
is_bs4_available,
28-
is_ftfy_available,
2927
is_torch_xla_available,
3028
logging,
3129
replace_example_docstring,
@@ -44,12 +42,6 @@
4442
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4543

4644

47-
if is_bs4_available():
48-
pass
49-
50-
if is_ftfy_available():
51-
pass
52-
5345
EXAMPLE_DOC_STRING = """
5446
Examples:
5547
```py

0 commit comments

Comments
 (0)