Skip to content

Commit 4938f55

Browse files
committed
more fixes
1 parent 08beacd commit 4938f55

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flaxdiff/models/simple_vit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ def setup(self):
295295
)
296296

297297
self.time_embed = nn.Sequential([
298-
FourierEmbedding(features=self.emb_features, dtype=jnp.float32),
299-
TimeProjection(features=self.emb_features * self.mlp_ratio, dtype=self.dtype, precision=self.precision),
298+
FourierEmbedding(features=self.emb_features),
299+
TimeProjection(features=self.emb_features * self.mlp_ratio),
300300
nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)
301301
], name="time_embed")
302302

0 commit comments

Comments
 (0)