 def adapt_batch_size(
     engine: OTXEngine,
     not_increase: bool = True,
-    callbacks: list[Callback] | Callback | None = None,
     **train_args,
 ) -> None:
     """Change the actual batch size depending on the current GPU status.
@@ -39,7 +38,6 @@ def adapt_batch_size(
     Args:
         engine (OTXEngine): engine instance.
         not_increase (bool): Whether to prevent the batch size from being increased above the default value.
-        callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
     """
     if not (is_cuda_available() or is_xpu_available()):
         msg = "Adaptive batch size supports only CUDA or XPU."
@@ -55,7 +53,7 @@ def adapt_batch_size(
         _apply_new_batch_size(engine, new_batch_size)
         return

-    train_func = partial(_train_model, engine=engine, callbacks=callbacks, **_adjust_train_args(train_args))
+    train_func = partial(_train_model, engine=engine, **_adjust_train_args(train_args))
     bs_search_algo = BsSearchAlgo(
         train_func=train_func,
         default_bs=default_bs,
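Note on the changed train_func line above: functools.partial pre-binds engine and the adjusted keyword arguments, so the batch-size search only has to pass a candidate bs on each trial. A minimal sketch of that pattern, using a hypothetical _train_trial stand-in and made-up arguments rather than the real _train_model/BsSearchAlgo:

from functools import partial

def _train_trial(bs: int, engine: str, **train_args) -> None:
    # Stand-in for _train_model: only `bs` varies between trials.
    print(f"trial: bs={bs}, engine={engine}, extra={train_args}")

# Everything except the batch size is bound up front, mirroring the diff above.
train_func = partial(_train_trial, engine="dummy-engine", max_epochs=1)

for candidate_bs in (2, 4, 8):  # a search algorithm would pick these adaptively
    train_func(candidate_bs)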
@@ -85,11 +83,12 @@ def adapt_batch_size(
 def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]:
     train_args.update(train_args.pop("kwargs", {}))
     train_args.pop("self", None)
-    train_args.pop("adaptive_bs")
+    train_args.pop("adaptive_bs", None)
+    train_args.pop("callbacks", None)
     return train_args


-def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None:
+def _train_model(bs: int, engine: OTXEngine, **train_args) -> None:
     if bs <= 0:
         msg = f"Batch size should be greater than 0, but {bs} is given."
         raise ValueError(msg)
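A note on the _adjust_train_args change in this hunk: dict.pop(key) raises KeyError when the key is absent, while dict.pop(key, None) returns the default instead, so the new form tolerates train_args that never contained "adaptive_bs" or "callbacks". A quick illustration with a made-up args dict:

args = {"max_epochs": 1}        # neither "adaptive_bs" nor "callbacks" present
args.pop("adaptive_bs", None)   # the old form, args.pop("adaptive_bs"), would raise KeyError here
args.pop("callbacks", None)     # also a safe no-op
print(args)                     # {'max_epochs': 1}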
@@ -100,7 +99,8 @@ def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callbac
     engine.datamodule.val_subset.batch_size = bs
     engine.datamodule.test_subset.batch_size = bs
     train_args["adaptive_bs"] = "None"
-    engine.train(callbacks=_register_callback(callbacks), **train_args)
+    print(f"Running training trial with bs = {bs}...")
+    engine.train(callbacks=_register_callback(), **train_args)


 def _register_callback(callbacks: list[Callback] | Callback | None = None) -> list[Callback]:
@@ -114,9 +114,13 @@ def _register_callback(callbacks: list[Callback] | Callback | None = None) -> li

 def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None:
     origin_bs = engine.datamodule.train_subset.batch_size
+    if is_xpu_available() and new_batch_size != 1:
+        new_batch_size -= 1  # for safety reasons
     if new_batch_size == origin_bs:
         return
     engine.datamodule.train_subset.batch_size = new_batch_size
     engine.datamodule.val_subset.batch_size = new_batch_size
     engine.datamodule.test_subset.batch_size = new_batch_size
-    engine.model.optimizer_callable.optimizer_kwargs["lr"] *= sqrt(new_batch_size / origin_bs)  # type: ignore[attr-defined]
+    new_lr = engine.model.optimizer_callable.optimizer_kwargs["lr"] * sqrt(new_batch_size / origin_bs)  # type: ignore[attr-defined]
+    print(f"New batch size = {new_batch_size} with learning rate = {new_lr} is set for training and validation.")
+    engine.model.optimizer_callable.optimizer_kwargs["lr"] = new_lr  # type: ignore[attr-defined]
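The last hunk rescales the learning rate by the square root of the batch-size ratio whenever a new batch size is applied. A small self-contained sketch of that rule (plain Python with a hypothetical scale_lr helper, not tied to OTXEngine):

from math import sqrt

def scale_lr(lr: float, origin_bs: int, new_bs: int) -> float:
    # Square-root scaling: quartering the batch size halves the learning rate.
    return lr * sqrt(new_bs / origin_bs)

print(scale_lr(0.01, origin_bs=32, new_bs=8))  # 0.005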