@@ -224,7 +224,8 @@ def __init__(
224
224
to_quant_block_names : Union [str , list , None ] = kwargs .pop ("to_quant_block_names" , None )
225
225
enable_norm_bias_tuning : bool = kwargs .pop ("enable_norm_bias_tuning" , False )
226
226
enable_quanted_input : bool = kwargs .pop ("enable_quanted_input" , True )
227
- disable_deterministic_algorithms = kwargs .pop ("disable_deterministic_algorithms" , False )
227
+ disable_deterministic_algorithms = kwargs .pop ("disable_deterministic_algorithms" , True )
228
+ enable_deterministic_algorithms = kwargs .pop ("enable_deterministic_algorithms" , False )
228
229
static_kv_dtype = kwargs .pop ("static_kv_dtype" , None )
229
230
device = kwargs .pop ("device" , None )
230
231
self .quant_lm_head = kwargs .pop ("quant_lm_head" , False )
@@ -234,11 +235,19 @@ def __init__(
234
235
235
236
if kwargs :
236
237
logger .warning (f"unrecognized keys { list (kwargs .keys ())} were passed. Please check them." )
238
+ if "CUBLAS_WORKSPACE_CONFIG" not in os .environ :
239
+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"
240
+ # deprecated, default not to use torch.use_deterministic_algorithms
241
+ if not disable_deterministic_algorithms or enable_deterministic_algorithms :
242
+ if not disable_deterministic_algorithms :
243
+ logger .warning (
244
+ "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"
245
+ " please use enable_deterministic_algorithms instead. "
246
+ )
237
247
238
- if not disable_deterministic_algorithms :
239
- if "CUBLAS_WORKSPACE_CONFIG" not in os .environ :
240
- os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"
241
248
torch .use_deterministic_algorithms (True , warn_only = False )
249
+ else :
250
+ torch .use_deterministic_algorithms (True , warn_only = True )
242
251
243
252
if device is not None :
244
253
logger .warning ("`device` is deprecated, please use `device_map` instead" )
0 commit comments