22#
33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
5- from typing import Optional , TypeVar , Union , overload
5+ from typing import Any , Dict , Optional , TypeVar , Union , overload
66
77import warnings
88import torch
@@ -142,7 +142,7 @@ def forward(self, input: Tensor) -> Tensor:
142142
143143class Params4bit (torch .nn .Parameter ):
144144
145- def __new__ (cls , data = None , requires_grad = True , quant_state = None , blocksize = 64 , compress_statistics = True , quant_type = 'fp4' ):
145+ def __new__ (cls , data : Optional [ torch . Tensor ] = None , requires_grad = True , quant_state : QuantState = None , blocksize : int = 64 , compress_statistics : bool = True , quant_type : str = 'fp4' ) -> "Params4bit" :
146146 if data is None :
147147 data = torch .empty (0 )
148148
@@ -155,7 +155,7 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
155155 return self
156156
157157 @classmethod
158- def from_prequantized (cls , data , quantized_stats , requires_grad = False , device = 'cuda' , ** kwargs ):
158+ def from_prequantized (cls , data : torch . Tensor , quantized_stats : Dict [ str , Any ], requires_grad : bool = False , device = 'cuda' , ** kwargs ) -> "Params4bit" :
159159 self = torch .Tensor ._make_subclass (cls , data .to (device ))
160160 self .requires_grad = requires_grad
161161 self .quant_state = QuantState .from_dict (qs_dict = quantized_stats , device = device )
0 commit comments