@@ -31,6 +31,8 @@ def __init__(
         scheduler: str = "reduce_on_plateau",
         scheduler_params: Dict[str, Any] = None,
         log_freq: int = 100,
+        auto_lr_finder: bool = False,
+        **kwargs,
     ) -> None:
         """Segmentation model training experiment.
@@ -64,6 +66,7 @@ def __init__(
                 optim params like learning rates, weight decays, etc. for different parts of
                 the network.
                 E.g. {"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
+                or {"learning_rate": 0.005, "weight_decay": 0.03}
             lookahead : bool, default=False
                 Flag whether the optimizer uses lookahead.
             scheduler : str, default="reduce_on_plateau"
@@ -75,6 +78,8 @@ def __init__(
                 for the possible scheduler arguments.
             log_freq : int, default=100
                 Return logs every n batches in logging callbacks.
+            auto_lr_finder : bool, default=False
+                Flag for whether to use the Lightning built-in auto-lr-finder.

        Raises
        ------
@@ -83,6 +88,8 @@ def __init__(
            ValueError if illegal metric names are given.
            ValueError if illegal optimizer name is given.
            ValueError if illegal scheduler name is given.
+            KeyError if `auto_lr_finder` is set to True and `optim_params` does
+                not contain an `lr` key.
        """
        super().__init__()
        self.model = model
@@ -95,6 +102,16 @@ def __init__(
        self.scheduler = scheduler
        self.scheduler_params = scheduler_params
        self.lookahead = lookahead
+        self.auto_lr_finder = auto_lr_finder
+
+        if auto_lr_finder:
+            try:
+                self.lr = optim_params["lr"]
+            except KeyError:
+                raise KeyError(
+                    "To use the Lightning built-in auto_lr_finder, `optim_params` "
+                    "has to contain an 'lr' key for the learning rate."
+                )

        self.branch_losses = branch_losses
        self.branch_metrics = branch_metrics
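
The `self.lr` attribute set above is exactly what Lightning's learning-rate finder reads and overwrites when it tunes the model. Below is a minimal usage sketch, assuming a pre-2.0 PyTorch Lightning API (`Trainer(auto_lr_find=True)` together with `trainer.tune(...)`); the `SegExperiment`, `model`, and `datamodule` names are placeholders, not part of this diff.

```python
import pytorch_lightning as pl

# Hypothetical names: `SegExperiment`, `model`, and `datamodule` stand in for
# the experiment class and data objects used in this repo.
experiment = SegExperiment(
    model,
    optim_params={"lr": 1e-3},  # the "lr" key is mandatory when auto_lr_finder=True
    auto_lr_finder=True,
)

trainer = pl.Trainer(auto_lr_find=True, max_epochs=10)

# `tune` runs the LR range test and overwrites `experiment.lr` in place,
# so the tuned value is picked up by `configure_optimizers` (next hunk).
trainer.tune(experiment, datamodule=datamodule)
trainer.fit(experiment, datamodule=datamodule)
```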
@@ -309,15 +326,20 @@ def configure_optimizers(self):
                f"Illegal scheduler given. Got {self.scheduler}. Allowed: {allowed}."
            )

-        # set a sensible default if None.
-        if self.optim_params is None:
-            self.optim_params = {
-                "encoder": {"lr": 0.00005, "weight_decay": 0.00003},
-                "decoder": {"lr": 0.0005, "weight_decay": 0.0003},
-            }
+        if not self.auto_lr_finder:
+            # set a sensible default if None.
+            if self.optim_params is None:
+                self.optim_params = {
+                    "encoder": {"lr": 0.00005, "weight_decay": 0.00005},
+                    "decoder": {"lr": 0.0005, "weight_decay": 0.0005},
+                }

-        params = adjust_optim_params(self.model, self.optim_params)
-        optimizer = OPTIM_LOOKUP[self.optimizer](params)
+            params = adjust_optim_params(self.model, self.optim_params)
+            optimizer = OPTIM_LOOKUP[self.optimizer](params)
+        else:
+            optimizer = OPTIM_LOOKUP[self.optimizer](
+                self.model.parameters(), lr=self.lr
+            )

        if self.lookahead:
            optimizer = OPTIM_LOOKUP["lookahead"](optimizer, k=5, alpha=0.5)
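
For reference, `adjust_optim_params` itself is not part of this diff, so the sketch below only illustrates one plausible way the nested `optim_params` dict could expand into per-part optimizer parameter groups, under the assumption that the part keys ("encoder", "decoder") match substrings of the model's parameter names; `make_param_groups` is a hypothetical helper, and `model` is a placeholder.

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn

def make_param_groups(
    model: nn.Module, optim_params: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Group parameters by matching each part key against parameter names."""
    return [
        {
            "params": [p for name, p in model.named_parameters() if part in name],
            **settings,  # e.g. {"lr": 0.0005, "weight_decay": 0.0005}
        }
        for part, settings in optim_params.items()
    ]

# Mirrors the defaults set in `configure_optimizers` above.
groups = make_param_groups(
    model,
    {
        "encoder": {"lr": 0.00005, "weight_decay": 0.00005},
        "decoder": {"lr": 0.0005, "weight_decay": 0.0005},
    },
)
optimizer = torch.optim.AdamW(groups)
```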