@@ -64,9 +64,9 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None)
6464 parameters (`torch.Tensor` or `list(torch.Tensors)`):
6565 The input parameters.
6666 key (`str`):
67- The hyperparamter to override.
67+ The hyperparameter to override.
6868 value:
69- The hyperparameter values .
69+ The hyperparameter value.
7070 key_value_dict (`dict`):
7171 A dictionary with multiple key-values to override.
7272
@@ -115,7 +115,7 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False):
115115 Base 8-bit optimizer class.
116116
117117 Arguments:
118- params (`torch.tensor `):
118+ params (`torch.Tensor`):
119119 The input parameters to optimize.
120120 optim_bits (`int`, defaults to 32):
121121 The number of bits of the optimizer state.
@@ -291,7 +291,7 @@ def step(self, closure=None):
291291 self .update_step (group , p , gindex , pindex )
292292 torch .cuda .synchronize ()
293293 if self .is_paged :
294- # all paged operation are asynchronous, we need
294+ # all paged operations are asynchronous, we need
295295 # to sync to make sure all tensors are in the right state
296296 torch .cuda .synchronize ()
297297
@@ -371,7 +371,7 @@ def __init__(
371371 Arguments:
372372 optimizer_name (`str`):
373373 The name of the optimizer.
374- params (`torch.tensor `):
374+ params (`torch.Tensor`):
375375 The input parameters to optimize.
376376 lr (`float`, defaults to 1e-3):
377377 The learning rate.
@@ -428,7 +428,6 @@ def __init__(
428428 if args is None :
429429 args = {}
430430 args ["optim_bits" ] = optim_bits
431- args ["percentile_clipping" ] = 100
432431 args ["min_8bit_size" ] = min_8bit_size
433432 args ["percentile_clipping" ] = percentile_clipping
434433 args ["block_wise" ] = block_wise
@@ -613,7 +612,7 @@ def __init__(
613612 Arguments:
614613 optimizer_name (`str`):
615614 The name of the optimizer.
616- params (`torch.tensor `):
615+ params (`torch.Tensor`):
617616 The input parameters to optimize.
618617 lr (`float`, defaults to 1e-3):
619618 The learning rate.
@@ -655,7 +654,6 @@ def __init__(
655654 if args is None :
656655 args = {}
657656 args ["optim_bits" ] = optim_bits
658- args ["percentile_clipping" ] = 100
659657 args ["min_8bit_size" ] = min_8bit_size
660658 args ["percentile_clipping" ] = percentile_clipping
661659 args ["block_wise" ] = block_wise
0 commit comments