Skip to content

Commit cefc7b6

Browse files
authored
Remove einops dep (#2006)
We shouldn't add a new dependency for just to reshape calls. We can do this the less readable way for now.
1 parent da6db4e commit cefc7b6

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

keras_hub/src/models/flux/flux_maths.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import keras
2-
from einops import rearrange
32
from keras import ops
43

54

@@ -58,7 +57,7 @@ def call(self, pos, dim, theta):
5857
out = ops.stack(
5958
[ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
6059
)
61-
out = rearrange(out, "... n d (i j) -> ... n d i j", i=2, j=2)
60+
out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
6261
return ops.cast(out, dtype="float32")
6362

6463

@@ -122,9 +121,9 @@ def call(self, q, k, v, positional_encoding):
122121
x = scaled_dot_product_attention(
123122
q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
124123
)
125-
126-
x = rearrange(x, "B H L D -> B L (H D)")
127-
return x
124+
x = ops.transpose(x, (0, 2, 1, 3))
125+
b, l, h, d = ops.shape(x)
126+
return ops.reshape(x, (b, l, h * d))
128127

129128

130129
# TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the original

requirements-common.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,3 @@ sentencepiece
1919
tensorflow-datasets
2020
safetensors
2121
pillow
22-
# Will be replaced once https://github.com/keras-team/keras/issues/20332
23-
# is resolved
24-
einops

0 commit comments

Comments
 (0)