|
1 | 1 | import keras |
2 | | -from einops import rearrange |
3 | 2 | from keras import ops |
4 | 3 |
|
5 | 4 |
|
@@ -58,7 +57,7 @@ def call(self, pos, dim, theta): |
58 | 57 | out = ops.stack( |
59 | 58 | [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1 |
60 | 59 | ) |
61 | | - out = rearrange(out, "... n d (i j) -> ... n d i j", i=2, j=2) |
| 60 | + out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2)) |
62 | 61 | return ops.cast(out, dtype="float32") |
63 | 62 |
|
64 | 63 |
|
@@ -122,9 +121,9 @@ def call(self, q, k, v, positional_encoding): |
122 | 121 | x = scaled_dot_product_attention( |
123 | 122 | q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal |
124 | 123 | ) |
125 | | - |
126 | | - x = rearrange(x, "B H L D -> B L (H D)") |
127 | | - return x |
| 124 | + x = ops.transpose(x, (0, 2, 1, 3)) |
| 125 | + b, l, h, d = ops.shape(x) |
| 126 | + return ops.reshape(x, (b, l, h * d)) |
128 | 127 |
|
129 | 128 |
|
130 | 129 | # TODO: This is probably already implemented in several places, but it is needed here to ensure numeric equivalence with the original implementation. |
|
0 commit comments