Skip to content

Commit 7a991da

Browse files
committed
fix: fixed conv layer stuff
1 parent 82e962b commit 7a991da

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

flaxdiff/models/simple_vit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,22 @@ def setup(self):
205205
if self.add_residualblock_output:
206206
# Define these layers only if needed
207207
self.final_conv1 = ConvLayer(
208+
"conv",
208209
features=64, kernel_size=(3, 3), strides=(1, 1),
209210
dtype=self.dtype, precision=self.precision, name="final_conv1"
210211
)
211212
self.final_norm_conv = self.norm_factory(
212213
name="final_norm_conv") # Use factory
213214
self.final_conv2 = ConvLayer(
215+
"conv",
214216
features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
215217
dtype=jnp.float32, # Often good to have final conv output float32
216218
precision=self.precision, name="final_conv2"
217219
)
218220
else:
219221
# Final conv to map features to output channels directly after unpatchify
220222
self.final_conv_direct = ConvLayer(
223+
"conv",
221224
# Use 1x1 conv
222225
features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
223226
dtype=jnp.float32, # Output float32

0 commit comments

Comments
 (0)