Skip to content

Commit 5486053

Browse files
committed
type hints in Params4bit constructors
1 parent 74c00eb commit 5486053

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import Optional, TypeVar, Union, overload
5+
from typing import Any, Dict, Optional, TypeVar, Union, overload
66

77
import warnings
88
import torch
@@ -142,7 +142,7 @@ def forward(self, input: Tensor) -> Tensor:
142142

143143
class Params4bit(torch.nn.Parameter):
144144

145-
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
145+
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
146146
if data is None:
147147
data = torch.empty(0)
148148

@@ -155,7 +155,7 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
155155
return self
156156

157157
@classmethod
158-
def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
158+
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
159159
self = torch.Tensor._make_subclass(cls, data.to(device))
160160
self.requires_grad = requires_grad
161161
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)

0 commit comments

Comments
 (0)