Commit 5959518

add reshape to fix use_memory_efficient_attention in flax
1 parent be4afa0

File tree

1 file changed: +1 −1 lines

src/diffusers/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -216,8 +216,8 @@ def __call__(self, hidden_states, context=None, deterministic=True):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
```
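For context, the added `reshape_batch_dim_to_heads` call merges the per-head batch dimension back into the feature dimension so the attention output matches the shape the following projection layer expects. Below is a minimal sketch of that reshape (not the diffusers implementation itself), assuming the memory-efficient attention path, after the transpose, yields a tensor of shape `(batch * heads, seq_len, head_dim)`; the helper name and shapes here are illustrative:

```python
import jax.numpy as jnp


def reshape_batch_dim_to_heads(x, num_heads):
    # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
    bh, seq_len, head_dim = x.shape
    batch = bh // num_heads
    x = x.reshape(batch, num_heads, seq_len, head_dim)
    x = jnp.transpose(x, (0, 2, 1, 3))  # (batch, seq_len, heads, head_dim)
    return x.reshape(batch, seq_len, num_heads * head_dim)


# Example: batch=2, heads=8, seq_len=77, head_dim=40 (hypothetical sizes)
out = jnp.zeros((2 * 8, 77, 40))
merged = reshape_batch_dim_to_heads(out, num_heads=8)
print(merged.shape)  # (2, 77, 320)
```

Without this final merge, the memory-efficient branch would hand a `(batch * heads, ...)` tensor to the output projection, which is the shape mismatch this commit fixes.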
