1- #![ allow( dead_code, unused ) ]
1+ #![ allow( dead_code) ]
22#![ recursion_limit = "256" ]
33
44extern crate core;
@@ -8,18 +8,16 @@ mod dataset;
88use crate :: data:: { ClassificationBatch , ClassificationBatcher } ;
99use crate :: dataset:: { CLASSES , PlanetLoader , download} ;
1010use bimm:: cache:: disk:: DiskCacheConfig ;
11- use bimm:: models:: resnet:: { PREFAB_RESNET_MAP , ResNet , ResNetContractConfig } ;
11+ use bimm:: models:: resnet:: { PREFAB_RESNET_MAP , ResNet } ;
1212use burn:: backend:: Autodiff ;
1313use burn:: config:: Config ;
1414use burn:: data:: dataloader:: DataLoaderBuilder ;
1515use burn:: data:: dataset:: transform:: ShuffledDataset ;
1616use burn:: data:: dataset:: vision:: ImageFolderDataset ;
1717use burn:: module:: Module ;
18- use burn:: nn:: PReluConfig ;
1918use burn:: nn:: activation:: ActivationConfig ;
2019use burn:: nn:: loss:: BinaryCrossEntropyLossConfig ;
21- use burn:: optim:: AdamConfig ;
22- use burn:: optim:: decay:: WeightDecayConfig ;
20+ use burn:: optim:: AdamWConfig ;
2321use burn:: prelude:: { Int , Tensor } ;
2422use burn:: record:: CompactRecorder ;
2523use burn:: tensor:: backend:: { AutodiffBackend , Backend } ;
@@ -152,6 +150,11 @@ fn main() -> anyhow::Result<()> {
152150 return train :: < Autodiff < burn:: backend:: Metal > > ( & args) ;
153151}
154152
153+ pub fn reset_artifact_dir ( artifact_dir : & str ) -> anyhow:: Result < ( ) > {
154+ std:: fs:: remove_dir_all ( artifact_dir) ?;
155+ std:: fs:: create_dir_all ( artifact_dir) . map_err ( |e| anyhow:: anyhow!( e) )
156+ }
157+
155158#[ must_use]
156159pub fn train < B : AutodiffBackend > ( args : & Args ) -> anyhow:: Result < ( ) > {
157160 let device: B :: Device = Default :: default ( ) ;
@@ -189,8 +192,7 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
189192
190193 // Remove existing artifacts before to get an accurate learner summary
191194 let artifact_dir: & str = args. artifact_dir . as_ref ( ) ;
192- std:: fs:: remove_dir_all ( artifact_dir) ;
193- std:: fs:: create_dir_all ( artifact_dir) . expect ( "Failed to create artifacts directory" ) ;
195+ reset_artifact_dir ( artifact_dir) ?;
194196
195197 B :: seed ( & device, args. seed ) ;
196198
@@ -224,8 +226,9 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
224226 resnet : model,
225227 } ;
226228
227- let optimizer = AdamConfig :: new ( )
228- . with_weight_decay ( Some ( WeightDecayConfig :: new ( args. weight_decay ) ) )
229+ let optimizer = AdamWConfig :: new ( )
230+ . with_cautious_weight_decay ( true )
231+ . with_weight_decay ( args. weight_decay )
229232 . init ( ) ;
230233
231234 LogConfig {
0 commit comments