@@ -79,7 +79,7 @@ namespace dlib
     This demonstrates a modern, clean ViT implementation using:
     - patch_embeddings: splits image into patches + learned projection
     - canonical_transformer: Dlib's transformer with RoPE positioning
-    - Standard Dlib layers: fc, dropout, avg_pool_everything, ...
+    - Standard Dlib layers: fc, dropout, ...

     Architecture summary:
     Input (32x32 RGB) => Patches (4x4) => Embeddings (192-dim)
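As a quick sanity check of the shapes in that summary, here is a tiny standalone sketch (not part of the commit; every name in it is made up for illustration) reproducing the patch arithmetic:

    #include <cstdio>

    int main()
    {
        constexpr long image_size    = 32;  // 32x32 RGB input
        constexpr long patch_size    = 4;   // each patch covers 4x4 pixels
        constexpr long grid_side     = image_size / patch_size;  // 8 patches per side
        constexpr long num_patches   = grid_side * grid_side;    // (32/4)^2 = 64 tokens
        constexpr long embedding_dim = 192; // width of the learned patch projection
        static_assert(num_patches == 64, "the transformer consumes 64 tokens");
        std::printf("%ld x %ld grid -> %ld tokens of dim %ld\n",
                    grid_side, grid_side, num_patches, embedding_dim);
        return 0;
    }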
@@ -100,21 +100,21 @@ namespace dlib
     static constexpr long EMBEDDING_DIM = embedding_dim;
     static constexpr long PATCH_SIZE = 4;   // 32/4 = 8x8 = 64 patches
     static constexpr long NUM_PATCHES = 64; // (32/4)^2
+    static constexpr long DONT_USE_CLASS_TOKEN = 0;
+    static constexpr long DONT_USE_POSITION_EMBEDDINGS = 0;

     // Backbone: patch embeddings => transformer => pooling
     // Returns: (batch, embedding_dim) feature vectors
     template <template <typename> class DO, typename INPUT>
-    using backbone_training =
-        avg_pool_everything<
+    using backbone_training = rms_norm<
         canonical_transformer::transformer_stack<NUM_LAYERS, gelu, DO, EMBEDDING_DIM, NUM_HEADS,
-        patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, 1, 0, // cls=1, no pos_emb (use RoPE)
+        patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, DONT_USE_CLASS_TOKEN, DONT_USE_POSITION_EMBEDDINGS,
         INPUT>>>;

     template <typename INPUT>
-    using backbone_inference =
-        avg_pool_everything<
+    using backbone_inference = rms_norm<
         canonical_transformer::transformer_stack<NUM_LAYERS, gelu, multiply, EMBEDDING_DIM, NUM_HEADS,
-        patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, 1, 0,
+        patch_embeddings<PATCH_SIZE, EMBEDDING_DIM, DONT_USE_CLASS_TOKEN, DONT_USE_POSITION_EMBEDDINGS,
         INPUT>>>;

     static std::string describe() {
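Note how the two aliases differ only in the dropout slot: backbone_training is parameterized over a dropout-style layer DO, while backbone_inference hard-codes multiply, dlib's deterministic replacement for dropout. A minimal sketch of how that pairing is typically consumed (assuming this example file's model::supervised_train and model::supervised_inference aliases, which appear in the hunks below, are built on these backbones):

    // Sketch only: relies on the aliases defined elsewhere in this example.
    model::supervised_train     train_net; // dropout active while training
    model::supervised_inference infer_net; // multiply in place of dropout

    // ... run the trainer on train_net (see the hunks below) ...

    // dlib can construct multiply layers from trained dropout layers, so
    // assignment carries the weights over and inference is deterministic.
    infer_net = train_net;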
@@ -244,7 +244,7 @@ void train_ssl(
     cout << "Training without labels - Learning representations from augmentations\n" << endl;

     model::ssl_train net((loss_barlow_twins_(lambda)));
-    dnn_trainer<model::ssl_train, adamw> trainer(net, adamw(0.04, 0.9, 0.999));
+    dnn_trainer<model::ssl_train, adamw> trainer(net, adamw(0.01, 0.9, 0.999));
     trainer.set_learning_rate(learning_rate);
     trainer.set_min_learning_rate(min_learning_rate);
     trainer.set_mini_batch_size(batch_size);
@@ -305,7 +305,7 @@ void train_supervised(

     model::supervised_train net;
     model::supervised_inference inference_net;
-    dnn_trainer<model::supervised_train, adamw> trainer(net, adamw(0.04, 0.9, 0.999));
+    dnn_trainer<model::supervised_train, adamw> trainer(net, adamw(0.01, 0.9, 0.999));
     trainer.set_learning_rate(learning_rate);
     trainer.set_min_learning_rate(min_learning_rate);
     trainer.set_mini_batch_size(batch_size);
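For context on the solver change in the two hunks above: dlib's adam solver takes (weight_decay, momentum1, momentum2), and adamw here appears to follow the same signature, so both trainers lower the weight decay from 0.04 to 0.01 while keeping the momentum terms. A self-contained sketch of the same trainer pattern on a toy network (the network and the rate values are placeholders, not taken from this commit):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Toy stand-in for the example's real networks.
    using toy_net = loss_multiclass_log<fc<10, input<matrix<float>>>>;

    int main()
    {
        toy_net net;
        // adamw(weight_decay, momentum1, momentum2), as in the hunks above.
        dnn_trainer<toy_net, adamw> trainer(net, adamw(0.01, 0.9, 0.999));
        trainer.set_learning_rate(1e-3);     // assumed initial rate
        trainer.set_min_learning_rate(1e-6); // assumed stopping floor
        trainer.set_mini_batch_size(64);     // assumed batch size
        trainer.be_verbose();
        return 0;
    }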