Skip to content

Commit 7b2c4ef

Browse files
committed
Update
1 parent 1fc065d commit 7b2c4ef

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

examples/slm_vision_transformer_hybrid_ex.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ namespace dlib
7979
This demonstrates a modern, clean ViT implementation using:
8080
- patch_embeddings: splits image into patches + learned projection
8181
- canonical_transformer: Dlib's transformer with RoPE positioning
82-
- Standard Dlib layers: fc, dropout, avg_pool_everything, ...
82+
- Standard Dlib layers: fc, dropout, ...
8383
8484
Architecture summary:
8585
Input (32x32 RGB) => Patches (4x4) => Embeddings (192-dim)
@@ -100,21 +100,21 @@ namespace dlib
100100
static constexpr long EMBEDDING_DIM = embedding_dim;
101101
static constexpr long PATCH_SIZE = 4; // 32/4 = 8x8 = 64 patches
102102
static constexpr long NUM_PATCHES = 64; // (32/4)^2
103+
static constexpr long DONT_USE_ClASS_TOKEN = 0;
104+
static constexpr long DONT_USE_POSITION_EMBEDDINGS = 0;
103105

104106
// Backbone: patch embeddings => transformer => pooling
105107
// Returns: (batch, embedding_dim) feature vectors
106108
template <template <typename> class DO, typename INPUT>
107-
using backbone_training =
108-
avg_pool_everything<
109+
using backbone_training = rms_norm<
109110
canonical_transformer::transformer_stack<NUM_LAYERS, gelu, DO, EMBEDDING_DIM, NUM_HEADS,
110-
patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, 1, 0, // cls=1, no pos_emb (use RoPE)
111+
patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, DONT_USE_ClASS_TOKEN, DONT_USE_POSITION_EMBEDDINGS,
111112
INPUT>>>;
112113

113114
template <typename INPUT>
114-
using backbone_inference =
115-
avg_pool_everything<
115+
using backbone_inference = rms_norm<
116116
canonical_transformer::transformer_stack<NUM_LAYERS, gelu, multiply, EMBEDDING_DIM, NUM_HEADS,
117-
patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, 1, 0,
117+
patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, DONT_USE_ClASS_TOKEN, DONT_USE_POSITION_EMBEDDINGS,
118118
INPUT>>>;
119119

120120
static std::string describe() {
@@ -244,7 +244,7 @@ void train_ssl(
244244
cout << "Training without labels - Learning representations from augmentations\n" << endl;
245245

246246
model::ssl_train net((loss_barlow_twins_(lambda)));
247-
dnn_trainer<model::ssl_train, adamw> trainer(net, adamw(0.04, 0.9, 0.999));
247+
dnn_trainer<model::ssl_train, adamw> trainer(net, adamw(0.01, 0.9, 0.999));
248248
trainer.set_learning_rate(learning_rate);
249249
trainer.set_min_learning_rate(min_learning_rate);
250250
trainer.set_mini_batch_size(batch_size);
@@ -305,7 +305,7 @@ void train_supervised(
305305

306306
model::supervised_train net;
307307
model::supervised_inference inference_net;
308-
dnn_trainer<model::supervised_train, adamw> trainer(net, adamw(0.04, 0.9, 0.999));
308+
dnn_trainer<model::supervised_train, adamw> trainer(net, adamw(0.01, 0.9, 0.999));
309309
trainer.set_learning_rate(learning_rate);
310310
trainer.set_min_learning_rate(min_learning_rate);
311311
trainer.set_mini_batch_size(batch_size);

0 commit comments

Comments
 (0)