@@ -9,7 +9,7 @@ use crate::data::{ClassificationBatch, ClassificationBatcher};
99use crate :: dataset:: { CLASSES , PlanetLoader , download} ;
1010use bimm:: cache:: disk:: DiskCacheConfig ;
1111use bimm:: compat:: activation_wrapper:: ActivationConfig ;
12- use bimm:: models:: resnet:: { PREFAB_RESNET_MAP , ResNet } ;
12+ use bimm:: models:: resnet:: { PREFAB_RESNET_MAP , ResNet , ResNetContractConfig } ;
1313use burn:: backend:: { Autodiff , Cuda } ;
1414use burn:: config:: Config ;
1515use burn:: data:: dataloader:: DataLoaderBuilder ;
@@ -62,6 +62,14 @@ $ --drop-path-prob=0.15 --drop-block-prob=0.25 --learning-rate=1e-5
6262#[ derive( Parser , Debug ) ]
6363#[ command( author, version, about, long_about = None ) ]
6464pub struct Args {
65+ /// Random seed for reproducibility.
66+ #[ arg( short, long, default_value = "0" ) ]
67+ seed : u64 ,
68+
69+ /// Train percentage.
70+ #[ arg( long, default_value = "70" ) ]
71+ pub train_percentage : u8 ,
72+
6573 /// Directory to save the artifacts.
6674 #[ arg( long, default_value = "/tmp/resnet-finetune" ) ]
6775 artifact_dir : String ,
@@ -72,7 +80,7 @@ pub struct Args {
7280
7381 /// Number of workers for data loading.
7482 #[ arg( long, default_value = "2" ) ]
75- num_workers : Option < usize > ,
83+ num_workers : usize ,
7684
7785 /// Number of epochs to train the model.
7886 #[ arg( long, default_value = "60" ) ]
@@ -101,10 +109,11 @@ pub struct Args {
101109 /// Early stopping patience
102110 #[ arg( long, default_value = "10" ) ]
103111 patience : usize ,
104- }
105112
106- #[ allow( dead_code) ]
107- const ARTIFACT_DIR : & str = "/tmp/resnet-finetune" ;
113+ /// Optimizer Weight decay.
114+ #[ arg( long, default_value_t = 5e-4 ) ]
115+ pub weight_decay : f32 ,
116+ }
108117
109118fn main ( ) {
110119 let args = Args :: parse ( ) ;
@@ -115,101 +124,16 @@ fn main() {
115124 train :: < Autodiff < Cuda > > ( & args, & device) ;
116125}
117126
118- #[ allow( dead_code) ]
119- fn run < B : Backend > (
120- args : & Args ,
121- device : & B :: Device ,
122- ) {
123- train :: < Autodiff < B > > ( args, device) ;
124- // infer::<B>(ARTIFACT_DIR, device, 0.5);
125- }
126-
127- pub trait MultiLabelClassification < B : Backend > {
128- fn forward_classification (
129- & self ,
130- images : Tensor < B , 4 > ,
131- targets : Tensor < B , 2 , Int > ,
132- ) -> MultiLabelClassificationOutput < B > ;
133- }
134-
135- impl < B : Backend > MultiLabelClassification < B > for ResNet < B > {
136- fn forward_classification (
137- & self ,
138- images : Tensor < B , 4 > ,
139- targets : Tensor < B , 2 , Int > ,
140- ) -> MultiLabelClassificationOutput < B > {
141- let output = self . forward ( images) ;
142- let loss = BinaryCrossEntropyLossConfig :: new ( )
143- . with_logits ( true )
144- . init ( & output. device ( ) )
145- . forward ( output. clone ( ) , targets. clone ( ) ) ;
146-
147- MultiLabelClassificationOutput :: new ( loss, output, targets)
148- }
149- }
150-
151- impl < B : AutodiffBackend > TrainStep < ClassificationBatch < B > , MultiLabelClassificationOutput < B > >
152- for ResNet < B >
153- {
154- fn step (
155- & self ,
156- batch : ClassificationBatch < B > ,
157- ) -> TrainOutput < MultiLabelClassificationOutput < B > > {
158- let item = self . forward_classification ( batch. images , batch. targets ) ;
159-
160- TrainOutput :: new ( self , item. loss . backward ( ) , item)
161- }
162- }
163-
164- impl < B : Backend > ValidStep < ClassificationBatch < B > , MultiLabelClassificationOutput < B > >
165- for ResNet < B >
166- {
167- fn step (
168- & self ,
169- batch : ClassificationBatch < B > ,
170- ) -> MultiLabelClassificationOutput < B > {
171- self . forward_classification ( batch. images , batch. targets )
172- }
173- }
174-
175- #[ derive( Config ) ]
176- pub struct TrainingConfig {
177- #[ config( default = 5 ) ]
178- pub num_epochs : usize ,
179-
180- #[ config( default = 24 ) ]
181- pub batch_size : usize ,
182-
183- #[ config( default = 4 ) ]
184- pub num_workers : usize ,
185-
186- #[ config( default = 42 ) ]
187- pub seed : u64 ,
188-
189- #[ config( default = 1e-3 ) ]
190- pub learning_rate : f64 ,
191-
192- #[ config( default = 5e-5 ) ]
193- pub weight_decay : f32 ,
194-
195- #[ config( default = 70 ) ]
196- pub train_percentage : u8 ,
197-
198- pub num_classes : usize ,
199- }
200-
201- fn create_artifact_dir ( artifact_dir : & str ) {
202- // Remove existing artifacts before to get an accurate learner summary
203- std:: fs:: remove_dir_all ( artifact_dir) . ok ( ) ;
204- std:: fs:: create_dir_all ( artifact_dir) . ok ( ) ;
205- }
206-
207127pub fn train < B : AutodiffBackend > (
208128 args : & Args ,
209129 device : & B :: Device ,
210130) -> anyhow:: Result < ( ) > {
211- let artifact_dir = args. artifact_dir . as_ref ( ) ;
212- create_artifact_dir ( artifact_dir) ;
131+ // Remove existing artifacts before to get an accurate learner summary
132+ let artifact_dir: & str = args. artifact_dir . as_ref ( ) ;
133+ std:: fs:: remove_dir_all ( artifact_dir) ?;
134+ std:: fs:: create_dir_all ( artifact_dir) ?;
135+
136+ B :: seed ( args. seed ) ;
213137
214138 let disk_cache = DiskCacheConfig :: default ( ) ;
215139
@@ -220,54 +144,70 @@ pub fn train<B: AutodiffBackend>(
220144 . fetch_weights ( & disk_cache)
221145 . expect ( "Failed to fetch pretrained weights" ) ;
222146
223- let resnet_config = prefab
224- . to_config ( )
225- . with_activation ( ActivationConfig :: Gelu )
226- . to_structure ( ) ;
147+ let resnet_config = prefab. to_config ( ) . with_activation ( ActivationConfig :: Gelu ) ;
227148
228149 let model: ResNet < B > = resnet_config
150+ . clone ( )
151+ . to_structure ( )
229152 . init ( device)
230153 . load_pytorch_weights ( weights)
231154 . expect ( "Failed to load pretrained weights" )
232155 . with_classes ( CLASSES . len ( ) )
233156 . with_stochastic_drop_block ( args. drop_block_prob )
234157 . with_stochastic_path_depth ( args. drop_path_prob ) ;
235158
236- // Config
237- let training_config = TrainingConfig :: new ( CLASSES . len ( ) )
238- . with_learning_rate ( args. learning_rate )
239- . with_num_epochs ( args. num_epochs )
240- . with_batch_size ( args. batch_size ) ;
241-
242159 let optimizer = AdamConfig :: new ( )
243- . with_weight_decay ( Some ( WeightDecayConfig :: new ( training_config . weight_decay ) ) )
160+ . with_weight_decay ( Some ( WeightDecayConfig :: new ( args . weight_decay ) ) )
244161 . init ( ) ;
245162
246- training_config
247- . save ( format ! ( "{artifact_dir}/config.json" ) )
248- . expect ( "Config should be saved successfully" ) ;
249-
250- B :: seed ( training_config. seed ) ;
163+ #[ derive( Config ) ]
164+ struct LogConfig {
165+ seed : u64 ,
166+ train_percentage : u8 ,
167+ batch_size : usize ,
168+ num_epochs : usize ,
169+ resnet_prefab : String ,
170+ resnet_pretrained : String ,
171+ drop_block_prob : f64 ,
172+ drop_path_prob : f64 ,
173+ learning_rate : f64 ,
174+ patience : usize ,
175+ weight_decay : f32 ,
176+ resnet : ResNetContractConfig ,
177+ }
178+ LogConfig {
179+ seed : args. seed ,
180+ train_percentage : args. train_percentage ,
181+ batch_size : args. batch_size ,
182+ num_epochs : args. num_epochs ,
183+ resnet_prefab : args. resnet_prefab . clone ( ) ,
184+ resnet_pretrained : args. resnet_pretrained . clone ( ) ,
185+ drop_block_prob : args. drop_block_prob ,
186+ drop_path_prob : args. drop_path_prob ,
187+ learning_rate : args. learning_rate ,
188+ patience : args. patience ,
189+ weight_decay : args. weight_decay ,
190+ resnet : resnet_config,
191+ }
192+ . save ( format ! ( "{artifact_dir}/config.json" ) )
193+ . expect ( "Config should be saved successfully" ) ;
251194
252195 // Dataloaders
253196 let batcher_train = ClassificationBatcher :: < B > :: new ( device. clone ( ) ) ;
254197 let batcher_valid = ClassificationBatcher :: < B :: InnerBackend > :: new ( device. clone ( ) ) ;
255198
256- let ( train, valid) = ImageFolderDataset :: planet_train_val_split (
257- training_config. train_percentage ,
258- training_config. seed ,
259- )
260- . unwrap ( ) ;
199+ let ( train, valid) =
200+ ImageFolderDataset :: planet_train_val_split ( args. train_percentage , args. seed ) . unwrap ( ) ;
261201
262202 let dataloader_train = DataLoaderBuilder :: new ( batcher_train)
263- . batch_size ( training_config . batch_size )
264- . shuffle ( training_config . seed )
265- . num_workers ( training_config . num_workers )
266- . build ( ShuffledDataset :: with_seed ( train, training_config . seed ) ) ;
203+ . batch_size ( args . batch_size )
204+ . shuffle ( args . seed )
205+ . num_workers ( args . num_workers )
206+ . build ( ShuffledDataset :: with_seed ( train, args . seed ) ) ;
267207
268208 let dataloader_test = DataLoaderBuilder :: new ( batcher_valid)
269- . batch_size ( training_config . batch_size )
270- . num_workers ( training_config . num_workers )
209+ . batch_size ( args . batch_size )
210+ . num_workers ( args . num_workers )
271211 . build ( valid) ;
272212
273213 // Learner config
@@ -287,9 +227,9 @@ pub fn train<B: AutodiffBackend>(
287227 } ,
288228 ) )
289229 . devices ( vec ! [ device. clone( ) ] )
290- . num_epochs ( training_config . num_epochs )
230+ . num_epochs ( args . num_epochs )
291231 . summary ( )
292- . build ( model, optimizer, training_config . learning_rate ) ;
232+ . build ( model, optimizer, args . learning_rate ) ;
293233
294234 // Training
295235 let now = Instant :: now ( ) ;
@@ -303,3 +243,51 @@ pub fn train<B: AutodiffBackend>(
303243
304244 Ok ( ( ) )
305245}
246+
247+ pub trait MultiLabelClassification < B : Backend > {
248+ fn forward_classification (
249+ & self ,
250+ images : Tensor < B , 4 > ,
251+ targets : Tensor < B , 2 , Int > ,
252+ ) -> MultiLabelClassificationOutput < B > ;
253+ }
254+
255+ impl < B : Backend > MultiLabelClassification < B > for ResNet < B > {
256+ fn forward_classification (
257+ & self ,
258+ images : Tensor < B , 4 > ,
259+ targets : Tensor < B , 2 , Int > ,
260+ ) -> MultiLabelClassificationOutput < B > {
261+ let output = self . forward ( images) ;
262+ let loss = BinaryCrossEntropyLossConfig :: new ( )
263+ . with_logits ( true )
264+ . init ( & output. device ( ) )
265+ . forward ( output. clone ( ) , targets. clone ( ) ) ;
266+
267+ MultiLabelClassificationOutput :: new ( loss, output, targets)
268+ }
269+ }
270+
271+ impl < B : AutodiffBackend > TrainStep < ClassificationBatch < B > , MultiLabelClassificationOutput < B > >
272+ for ResNet < B >
273+ {
274+ fn step (
275+ & self ,
276+ batch : ClassificationBatch < B > ,
277+ ) -> TrainOutput < MultiLabelClassificationOutput < B > > {
278+ let item = self . forward_classification ( batch. images , batch. targets ) ;
279+
280+ TrainOutput :: new ( self , item. loss . backward ( ) , item)
281+ }
282+ }
283+
284+ impl < B : Backend > ValidStep < ClassificationBatch < B > , MultiLabelClassificationOutput < B > >
285+ for ResNet < B >
286+ {
287+ fn step (
288+ & self ,
289+ batch : ClassificationBatch < B > ,
290+ ) -> MultiLabelClassificationOutput < B > {
291+ self . forward_classification ( batch. images , batch. targets )
292+ }
293+ }
0 commit comments