
Commit 3d43b89

Rope moved out of attention class to be method (#176)
Signed-off-by: Kunjan patel <[email protected]>
1 parent: f68692a · commit: 3d43b89

File tree

2 files changed (+3, -4 lines)

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 from ....configuration_utils import ConfigMixin, flax_register_to_config
 from ...modeling_flax_utils import FlaxModelMixin
 from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero
-from ...attention_flax import FlaxFluxAttention
+from ...attention_flax import FlaxFluxAttention, apply_rope
 from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings)
 from .... import common_types
 from ....common_types import BlockSizes
@@ -131,7 +131,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     # since this function returns image_rotary_emb and passes it between layers,
     # we do not want to modify it
     image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered)
+    q, k = apply_rope(q, k, image_rotary_emb_reordered)
 
     q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1)
     k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1)
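
For context, a module-level apply_rope in the Flux architecture conventionally rotates each query/key pair by a per-position 2x2 rotation matrix, which is what the "n d (i j) -> n d i j" reordering above prepares. The sketch below is illustrative only and assumes the standard Flux rotary formulation; the actual function is the one imported from attention_flax.

import jax.numpy as jnp

def apply_rope(xq: jnp.ndarray, xk: jnp.ndarray, freqs_cis: jnp.ndarray):
  # Sketch of the standard Flux rotary embedding, not necessarily the exact
  # maxdiffusion implementation: split the head dimension into pairs and
  # rotate every pair by the 2x2 matrix stored in freqs_cis[..., i, j]
  # (freqs_cis is assumed broadcastable against the query/key shapes).
  xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  return xq_out.reshape(xq.shape).astype(xq.dtype), xk_out.reshape(xk.shape).astype(xk.dtype)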

src/maxdiffusion/train_utils.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@
 
 
 def get_first_step(state):
-  with jax.spmd_mode("allow_all"):
-    return int(state.step)
+  return int(state.step)
 
 
 def load_next_batch(train_iter, example_batch, config):
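
The train_utils.py change drops the jax.spmd_mode("allow_all") wrapper: when state.step is an ordinary fully-addressable (e.g. replicated) scalar array, int() on it works directly. A minimal self-contained sketch under that assumption, using a toy Flax TrainState rather than the real maxdiffusion state:

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

class Toy(nn.Module):
  # Toy model standing in for the maxdiffusion network (illustration only).
  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(x)

model = Toy()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))["params"]
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optax.sgd(1e-3))

def get_first_step(state):
  # Mirrors the new version in this commit: no jax.spmd_mode context is needed
  # to read a fully-addressable scalar like state.step.
  return int(state.step)

print(get_first_step(state))  # -> 0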
