@@ -157,7 +157,7 @@ def __init__(
157157 bias : bool = True ,
158158 device = None ,
159159 dtype = None ,
160- vectorize : bool = False ,
160+ vector_wise_quantization : bool = False ,
161161 mem_efficient : bool = False ,
162162 ):
163163 super ().__init__ (in_features , out_features , bias , device , dtype )
@@ -167,11 +167,11 @@ def __init__(
167167 Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''' )
168168
169169 # By default, we use the global quantization.
170- self .vectorize = vectorize
171- if self .vectorize :
170+ self .vector_wise_quantization = vector_wise_quantization
171+ if self .vector_wise_quantization :
172172 self ._fn = _switchback_vectorrize
173173 if mem_efficient :
174- print ('mem efficient is not supported for vectorize mode .' )
174+ print ('mem efficient is not supported for vector-wise quantization .' )
175175 exit (1 )
176176 else :
177177 if mem_efficient :
@@ -188,7 +188,7 @@ def prepare_for_eval(self):
188188 # m.prepare_for_eval()
189189 # model.apply(cond_prepare)
190190 print ('=> preparing for eval.' )
191- if self .vectorize :
191+ if self .vector_wise_quantization :
192192 W_int8 , state_W = quantize_rowwise (self .weight )
193193 else :
194194 W_int8 , state_W = quantize_global (self .weight )
@@ -210,7 +210,7 @@ def forward(self, x):
210210 X = x .view (- 1 , x .size (- 1 ))
211211 X_int8 , state_X = quantize_rowwise (X )
212212
213- if self .vectorize :
213+ if self .vector_wise_quantization :
214214 return int8_matmul_rowwise_dequantize (
215215 X_int8 , self .W_int8 .t (), state_X , self .state_W , self .bias
216216 ).view (* x .size ()[:- 1 ], - 1 )
@@ -219,9 +219,9 @@ def forward(self, x):
219219 X_int8 , self .W_int8 .t (), state_X , self .state_W , self .bias
220220 ).view (* x .size ()[:- 1 ], - 1 )
221221
222- SwitchBackLinearGlobal = partial (SwitchBackLinear , vectorize = False )
223- SwitchBackLinearGlobalMemEfficient = partial (SwitchBackLinear , vectorize = False , mem_efficient = True )
224- SwitchBackLinearVectorized = partial (SwitchBackLinear , vectorize = True )
222+ SwitchBackLinearGlobal = partial (SwitchBackLinear , vector_wise_quantization = False )
223+ SwitchBackLinearGlobalMemEfficient = partial (SwitchBackLinear , vector_wise_quantization = False , mem_efficient = True )
224+ SwitchBackLinearVectorwise = partial (SwitchBackLinear , vector_wise_quantization = True )
225225
226226# This is just the standard linear function.
227227class StandardLinearFunction (torch .autograd .Function ):
0 commit comments