@@ -33,7 +33,7 @@ def _scale_batch_size(
3333 max_trials : int = 25 ,
3434 batch_arg_name : str = "batch_size" ,
3535 margin : float = 0.05 ,
36- max_val : int = 1024 ,
36+ max_val : int = 8192 ,
3737) -> Optional [int ]:
3838 """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
3939 error.
@@ -179,7 +179,7 @@ def _run_power_scaling(
179179 batch_arg_name : str ,
180180 max_trials : int ,
181181 params : dict [str , Any ],
182- max_val : Optional [ int ] ,
182+ max_val : int = 8192 ,
183183) -> int :
184184 """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
185185 # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -192,7 +192,7 @@ def _run_power_scaling(
192192 # reset after each try
193193 _reset_progress (trainer )
194194
195- if max_val is not None and new_size >= max_val :
195+ if new_size >= max_val :
196196 rank_zero_info (f"Reached the maximum batch size limit of { max_val } . Stopping search." )
197197 break
198198
@@ -235,7 +235,7 @@ def _run_binsearch_scaling(
235235 max_trials : int ,
236236 params : dict [str , Any ],
237237 margin : float ,
238- max_val : Optional [ int ] ,
238+ max_val : int = 8192 ,
239239) -> int :
240240 """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.
241241
@@ -252,7 +252,7 @@ def _run_binsearch_scaling(
252252 # reset after each try
253253 _reset_progress (trainer )
254254
255- if max_val is not None and new_size >= max_val :
255+ if new_size >= max_val :
256256 rank_zero_info (f"Reached the maximum batch size limit of { max_val } . Stopping search." )
257257 break
258258
@@ -325,7 +325,7 @@ def _adjust_batch_size(
325325 factor : float = 1.0 ,
326326 value : Optional [int ] = None ,
327327 desc : Optional [str ] = None ,
328- max_val : Optional [ int ] = None ,
328+ max_val : int = 8192 ,
329329) -> tuple [int , bool ]:
330330 """Helper function for adjusting the batch size.
331331
@@ -336,7 +336,7 @@ def _adjust_batch_size(
336336 value: if a value is given, will override the batch size with this value.
337337 Note that the value of `factor` will not have an effect in this case
338338 desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
339- max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value .
339+ max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes .
340340
341341 Returns:
342342 The new batch size for the next trial and a bool that signals whether the
@@ -367,7 +367,7 @@ def _adjust_batch_size(
367367 new_size = value if value is not None else int (batch_size * factor )
368368
369369 # Apply max_val limit if provided
370- if max_val is not None and new_size > max_val :
370+ if new_size > max_val :
371371 if desc :
372372 rank_zero_info (f"Batch size { new_size } exceeds max_val limit { max_val } , capping at { max_val } " )
373373 new_size = max_val
0 commit comments