@@ -16,7 +16,9 @@ use burn::data::dataset::transform::ShuffledDataset;
 use burn::data::dataset::vision::ImageFolderDataset;
 use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
 use burn::module::Module;
+use burn::nn::activation::ActivationConfig;
 use burn::nn::loss::BinaryCrossEntropyLossConfig;
+use burn::nn::{LeakyReluConfig, PReluConfig};
 use burn::optim::AdamWConfig;
 use burn::prelude::{Int, Tensor};
 use burn::record::CompactRecorder;
@@ -33,8 +35,9 @@ use burn::train::{
     LearnerBuilder, LearningStrategy, MetricEarlyStoppingStrategy, MultiLabelClassificationOutput,
     StoppingCondition, TrainOutput, TrainStep, ValidStep,
 };
-use clap::{Parser, arg};
+use clap::{Parser, ValueEnum, arg};
 use core::clone::Clone;
+use serde::{Deserialize, Serialize};
 use std::time::Instant;
 /*
 tracel-ai/models reference:
@@ -46,6 +49,13 @@ tracel-ai/models reference:
 | Valid | Loss | 0.168 | 3 | 0.512 | 1 |
  */
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
+pub enum ReplaceActivationOption {
+    Gelu,
+    PRelu,
+    LeakyRelu,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
@@ -78,40 +88,44 @@ pub struct Args {
     pub num_workers: usize,
 
     /// Number of epochs to train the model.
-    #[arg(long, default_value = "200")]
+    #[arg(long, default_value = "60")]
     pub num_epochs: usize,
 
+    /// Early stopping patience
+    #[arg(long, default_value_t = 20)]
+    pub patience: usize,
+
     /// Pretrained Resnet Model.
     /// Use "list" to list all available pretrained models.
     #[arg(long, default_value = "resnet50.tv_in1k")]
     pub pretrained: String,
 
+    /// Replace activation function?
+    #[arg(long, default_value = "None")]
+    pub replace_activation: Option<ReplaceActivationOption>,
+
     /// Freeze the body layers during training.
     #[arg(long, default_value = "false")]
     pub freeze_layers: bool,
 
     /// Drop Block Prob
-    #[arg(long, default_value = "0.15")]
+    #[arg(long, default_value = "0.2")]
     pub drop_block_prob: f64,
 
     /// Drop Path Prob
     #[arg(long, default_value = "0.05")]
     pub stochastic_depth_prob: f64,
 
     /// Learning rate
-    #[arg(long, default_value_t = 5e-4)]
+    #[arg(long, default_value_t = 5e-5)]
     pub learning_rate: f64,
 
-    /// Early stopping patience
-    #[arg(long, default_value_t = 20)]
-    pub patience: usize,
-
     /// Enable cautious weight decay.
     #[arg(long, default_value = "false")]
     pub cautious_weight_decay: bool,
 
     /// Optimizer Weight decay.
-    #[arg(long, default_value_t = 5e-3)]
+    #[arg(long, default_value_t = 2e-2)]
     pub weight_decay: f32,
 }
 
@@ -212,7 +226,23 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
         .fetch_weights(&disk_cache)
         .expect("Failed to fetch pretrained weights");
 
-    let resnet_config = prefab.to_config();
+    let mut resnet_config = prefab.to_config();
+
+    if let Some(option) = &args.replace_activation {
+        match option {
+            ReplaceActivationOption::Gelu => {
+                resnet_config = resnet_config.with_activation(ActivationConfig::Gelu);
+            }
+            ReplaceActivationOption::PRelu => {
+                resnet_config =
+                    resnet_config.with_activation(ActivationConfig::PRelu(PReluConfig::new()));
+            }
+            ReplaceActivationOption::LeakyRelu => {
+                resnet_config = resnet_config
+                    .with_activation(ActivationConfig::LeakyRelu(LeakyReluConfig::new()));
+            }
+        }
+    }
 
     let mut model: ResNet<B> = resnet_config
         .clone()
@@ -283,50 +313,55 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
         .init()
         .expect("Failed to initialize learning rate scheduler");
 
-    // Learner config
-    let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_numeric(HammingScore::new())
-        .metric_valid_numeric(HammingScore::new())
-        .metric_train_numeric(LossMetric::new())
-        .metric_valid_numeric(LossMetric::new())
-        .metric_train(CudaMetric::new())
-        .metric_valid(CudaMetric::new())
-        .metric_train_numeric(CpuUse::new())
-        .metric_valid_numeric(CpuUse::new())
-        .metric_train_numeric(CpuMemory::new())
-        .metric_valid_numeric(CpuMemory::new())
-        .metric_train_numeric(LearningRateMetric::new())
-        .with_file_checkpointer(CompactRecorder::new())
-        .early_stopping(MetricEarlyStoppingStrategy::new(
-            &LossMetric::<B>::new(),
-            Aggregate::Mean,
-            Direction::Lowest,
-            Split::Valid,
-            StoppingCondition::NoImprovementSince {
-                n_epochs: args.patience,
-            },
-        ))
-        .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
-        .grads_accumulation(args.grads_accumulation)
-        .num_epochs(args.num_epochs)
-        .summary()
-        /*
-        .renderer(CustomRenderer {})
-        .with_application_logger(None)
-        */
-        .build(host, optimizer, lr_scheduler);
-
-    // Training
-    let now = Instant::now();
-    let model_trained = learner.fit(dataloader_train, dataloader_test);
+    let now: Instant;
+    {
+        // Learner config
+        let learner = LearnerBuilder::new(artifact_dir)
+            .metric_train_numeric(HammingScore::new())
+            .metric_valid_numeric(HammingScore::new())
+            .metric_train_numeric(LossMetric::new())
+            .metric_valid_numeric(LossMetric::new())
+            .metric_train(CudaMetric::new())
+            .metric_valid(CudaMetric::new())
+            .metric_train_numeric(CpuUse::new())
+            .metric_valid_numeric(CpuUse::new())
+            .metric_train_numeric(CpuMemory::new())
+            .metric_valid_numeric(CpuMemory::new())
+            .metric_train_numeric(LearningRateMetric::new())
+            .with_file_checkpointer(CompactRecorder::new())
+            .early_stopping(MetricEarlyStoppingStrategy::new(
+                &LossMetric::<B>::new(),
+                Aggregate::Mean,
+                Direction::Lowest,
+                Split::Valid,
+                StoppingCondition::NoImprovementSince {
+                    n_epochs: args.patience,
+                },
+            ))
+            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
+            .grads_accumulation(args.grads_accumulation)
+            .num_epochs(args.num_epochs)
+            .summary()
+            /*
+            .renderer(CustomRenderer {})
+            .with_application_logger(None)
+            */
+            .build(host, optimizer, lr_scheduler);
+
+        // Training
+        now = Instant::now();
+        let model_trained = learner.fit(dataloader_train, dataloader_test);
+
+        model_trained
+            .model
+            .resnet
+            .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
+            .expect("Trained model should be saved successfully");
+    }
     let elapsed = now.elapsed().as_secs();
     println!("Training completed in {}m{}s", (elapsed / 60), elapsed % 60);
 
-    model_trained
-        .model
-        .resnet
-        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
-        .expect("Trained model should be saved successfully");
+    println!("{:#?}", args);
 
     Ok(())
 }
0 commit comments