Merged
18 changes: 6 additions & 12 deletions bitsandbytes/optim/adamw.py
@@ -26,7 +26,7 @@ def __init__(
 Base AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -87,7 +87,7 @@ def __init__(
 8-bit AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -159,7 +159,7 @@ def __init__(
 32-bit AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -219,7 +219,7 @@ def __init__(
 Paged AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -241,8 +241,6 @@ def __init__(
     Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
 block_wise (`bool`, defaults to `True`):
     Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-is_paged (`bool`, defaults to `False`):
-    Whether the optimizer is a paged optimizer or not.
 """
 super().__init__(
 "adam",
@@ -279,7 +277,7 @@ def __init__(
 Paged 8-bit AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -303,8 +301,6 @@ def __init__(
     Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
 block_wise (`bool`, defaults to `True`):
     Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-is_paged (`bool`, defaults to `False`):
-    Whether the optimizer is a paged optimizer or not.
 """
 # Validate unsupported parameters
 if amsgrad:
@@ -350,7 +346,7 @@ def __init__(
 Paged 32-bit AdamW optimizer.
 
 Arguments:
-    params (`torch.tensor`):
+    params (`torch.Tensor`):
     The input parameters to optimize.
 lr (`float`, defaults to 1e-3):
     The learning rate.
@@ -372,8 +368,6 @@ def __init__(
     Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
 block_wise (`bool`, defaults to `True`):
     Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-is_paged (`bool`, defaults to `False`):
-    Whether the optimizer is a paged optimizer or not.
 """
 super().__init__(
 "adam",
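Several of the docstrings touched above describe the `block_wise` option: each block of a tensor is quantized independently by its own scale, so a single outlier only distorts its own block. The sketch below is a minimal, stdlib-only illustration of that idea (plain Python lists, symmetric absmax scaling to an int8-like range); it is not bitsandbytes' actual CUDA kernel, and the function names are hypothetical.

```python
def blockwise_quantize(values, block_size=4):
    """Quantize each block independently by its absolute maximum.

    Returns (quantized ints in [-127, 127], per-block scales).
    Illustrative sketch only, not the bitsandbytes implementation.
    """
    quantized, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale = max(abs(v) for v in block) or 1.0  # guard all-zero blocks
        scales.append(scale)
        quantized.extend(round(v / scale * 127) for v in block)
    return quantized, scales


def blockwise_dequantize(quantized, scales, block_size=4):
    """Invert blockwise_quantize using each block's stored scale."""
    return [q / 127 * scales[i // block_size] for i, q in enumerate(quantized)]
```

With one scale per block, a large value such as `100.0` saturates only its own block's range; small values in other blocks keep their precision, which is the stability benefit the docstrings describe.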