Skip to content

Commit 9d2446a

Browse files
committed
fix: do not pass dtype to fourier embedding
1 parent fb9e4b2 commit 9d2446a

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

flaxdiff/models/simple_vit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def setup(self):
127127

128128
# --- Conditioning ---
129129
self.time_embed = nn.Sequential([
130-
# Fourier often uses float32
131-
FourierEmbedding(features=self.emb_features, dtype=jnp.float32),
130+
FourierEmbedding(features=self.emb_features),
132131
TimeProjection(features=self.emb_features,
133132
dtype=self.dtype, precision=self.precision)
134133
], name="time_embed")

0 commit comments

Comments
 (0)