@@ -152,8 +152,7 @@ def train(
152152        min_epochs : int  =  1 ,
153153        seed : int  |  None  =  None ,
154154        deterministic : bool  |  Literal ["warn" ] =  False ,
155-         precision : _PRECISION_INPUT  |  None  =  None ,
156-         val_check_interval : int  |  float  |  None  =  None ,
155+         precision : _PRECISION_INPUT  |  None  =  16 ,
157156        callbacks : list [Callback ] |  Callback  |  None  =  None ,
158157        logger : Logger  |  Iterable [Logger ] |  bool  |  None  =  None ,
159158        resume : bool  =  False ,
@@ -162,7 +161,7 @@ def train(
162161        adaptive_bs : Literal ["None" , "Safe" , "Full" ] =  "None" ,
163162        check_val_every_n_epoch : int  |  None  =  1 ,
164163        num_sanity_val_steps : int  |  None  =  0 ,
165-         log_every_n_steps :  int  |  None  =  1 ,
164+         gradient_clip_val :  float  |  None  =  None ,
166165        ** kwargs ,
167166    ) ->  dict [str , Any ]:
168167        r"""Trains the model using the provided LightningModule and OTXDataModule. 
@@ -175,7 +174,6 @@ def train(
175174                Also, can be set to `warn` to avoid failures, because some operations don't 
176175                support deterministic mode. Defaults to False. 
177176            precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 16. 
178-             val_check_interval (int | float | None, optional): The validation check interval. Defaults to None. 
179177            callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training. 
180178            logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None. 
181179            resume (bool, optional): If True, tries to resume training from existing checkpoint. 
@@ -188,6 +186,7 @@ def train(
188186                Defaults to "None". 
189187            check_val_every_n_epoch (int | None, optional): How often to check validation. Defaults to 1. 
190188            num_sanity_val_steps (int | None, optional): Number of validation steps to run before training starts. 
189+             gradient_clip_val (float | None, optional): The value for gradient clipping. Defaults to None. 
191190            **kwargs: Additional keyword arguments for pl.Trainer configuration. 
192191
193192        Returns: 
@@ -243,10 +242,9 @@ def train(
243242            max_epochs = max_epochs ,
244243            min_epochs = min_epochs ,
245244            deterministic = deterministic ,
246-             val_check_interval = val_check_interval ,
247245            check_val_every_n_epoch = check_val_every_n_epoch ,
248246            num_sanity_val_steps = num_sanity_val_steps ,
249-             log_every_n_steps = log_every_n_steps ,
247+             gradient_clip_val = gradient_clip_val ,
250248            ** kwargs ,
251249        )
252250        fit_kwargs : dict [str , Any ] =  {}
@@ -877,13 +875,18 @@ def _apply_param_overrides(self, param_kwargs: dict[str, Any]) -> None:
877875        """Apply parameter overrides based on the current local variables.""" 
878876        sig  =  inspect .signature (self .train )
879877        add_kwargs  =  param_kwargs .pop ("kwargs" , {})
880-         self ._cache .update (** add_kwargs )
881878        for  param_name , param  in  sig .parameters .items ():
882-             if  param_name  in  param_kwargs :
883-                 current_value  =  param_kwargs [param_name ]
884-                 # Apply override if current value matches default and we have an override 
885-                 if  (current_value  !=  param .default ) or  (param_name  not  in   self ._cache .args ):
879+             if  param_name  in  param_kwargs  and  param_name  in  self ._cache .args :
880+                 # if both `param_kwargs` and `_cache.args` have the same parameter, 
881+                 # we will use the value from `param_kwargs` if it is different from the default 
882+                 # value of the parameter. 
883+                 # Otherwise, we will keep the value from `_cache.args`. 
884+                 current_value  =  param_kwargs .pop (param_name )
885+                 if  current_value  !=  param .default :
886886                    self ._cache .args [param_name ] =  current_value 
887+         # update the cache with the remaining parameters 
888+         self ._cache .update (** param_kwargs )
889+         self ._cache .update (** add_kwargs )
887890
888891    def  configure_accelerator (self ) ->  None :
889892        """Updates the cache arguments based on the device type.""" 
0 commit comments