
Commit 76bc5f0

n1ck-guo authored and chensuyue committed
fix cuda ut bug of use_deterministic_algorithms (#805)
(cherry picked from commit 886d657)
1 parent 520c78f commit 76bc5f0

File tree: 3 files changed (+38, -9 lines)

auto_round/autoround.py

Lines changed: 13 additions & 4 deletions
@@ -224,7 +224,8 @@ def __init__(
         to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
         enable_norm_bias_tuning: bool = kwargs.pop("enable_norm_bias_tuning", False)
         enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
-        disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", False)
+        disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
+        enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
         device = kwargs.pop("device", None)
         self.quant_lm_head = kwargs.pop("quant_lm_head", False)
@@ -234,11 +235,19 @@ def __init__(
 
         if kwargs:
             logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
+        if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+        # deprecated, default not to use torch.use_deterministic_algorithms
+        if not disable_deterministic_algorithms or enable_deterministic_algorithms:
+            if not disable_deterministic_algorithms:
+                logger.warning(
+                    "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"
+                    " please use enable_deterministic_algorithms instead. "
+                )
 
-        if not disable_deterministic_algorithms:
-            if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
-                os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
             torch.use_deterministic_algorithms(True, warn_only=False)
+        else:
+            torch.use_deterministic_algorithms(True, warn_only=True)
 
         if device is not None:
             logger.warning("`device` is deprecated, please use `device_map` instead")
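
In effect, the cuBLAS workspace variable is now always exported and strict determinism becomes opt-in: the deprecated disable_deterministic_algorithms kwarg defaults to True, and the new enable_deterministic_algorithms kwarg (default False) is what selects warn_only=False. A minimal standalone Python sketch of the resulting behavior (the helper name is illustrative, not part of the repo):

    import os
    import torch

    def configure_determinism(enable_deterministic_algorithms: bool = False,
                              disable_deterministic_algorithms: bool = True) -> None:
        # Always advertise a fixed cuBLAS workspace so deterministic CUDA GEMMs are possible.
        if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        if not disable_deterministic_algorithms or enable_deterministic_algorithms:
            # Deprecated opt-out explicitly set to False, or the new opt-in flag:
            # raise on ops that have no deterministic implementation.
            torch.use_deterministic_algorithms(True, warn_only=False)
        else:
            # New default: still request determinism, but only warn on unsupported ops.
            torch.use_deterministic_algorithms(True, warn_only=True)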

auto_round/script/llm.py

Lines changed: 12 additions & 2 deletions
@@ -206,7 +206,12 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--enable_alg_ext", action="store_true", help="whether to enable probably better algorithm")
 
         self.add_argument(
-            "--disable_deterministic_algorithms", action="store_true", help="disable torch deterministic algorithms."
+            "--disable_deterministic_algorithms",
+            action="store_true",
+            help="deprecated, disable torch deterministic algorithms.",
+        )
+        self.add_argument(
+            "--enable_deterministic_algorithms", action="store_true", help="enable torch deterministic algorithms."
         )
 
         self.add_argument(
@@ -543,6 +548,11 @@ def tune(args):
     scheme = args.scheme.upper()
     if scheme not in PRESET_SCHEMES:
         raise ValueError(f"{scheme} is not supported. only {PRESET_SCHEMES.keys()} are supported ")
+    if args.disable_deterministic_algorithms:
+        logger.warning(
+            "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"
+            " please use enable_deterministic_algorithms instead. "
+        )
     autoround = round(
         model=model,
         tokenizer=tokenizer,
@@ -580,7 +590,7 @@ def tune(args):
         super_group_size=args.super_group_size,
         super_bits=args.super_bits,
         disable_opt_rtn=args.disable_opt_rtn,
-        disable_deterministic_algorithms=args.disable_deterministic_algorithms,
+        enable_deterministic_algorithms=args.enable_deterministic_algorithms,
         enable_alg_ext=args.enable_alg_ext,
         **mllm_kwargs,
     )
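
The same rename flows through the CLI: --disable_deterministic_algorithms is kept only to emit a deprecation warning, and --enable_deterministic_algorithms is what reaches the round(...) call. A self-contained argparse sketch mirroring the flags above (the bare ArgumentParser and logger are stand-ins for the project's own classes):

    import argparse
    import logging

    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--disable_deterministic_algorithms",
        action="store_true",
        help="deprecated, disable torch deterministic algorithms.",
    )
    parser.add_argument(
        "--enable_deterministic_algorithms", action="store_true", help="enable torch deterministic algorithms."
    )

    args = parser.parse_args(["--enable_deterministic_algorithms"])
    if args.disable_deterministic_algorithms:
        logger.warning(
            "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"
            " please use enable_deterministic_algorithms instead. "
        )
    # Only the new flag is forwarded on, e.g.:
    # round(..., enable_deterministic_algorithms=args.enable_deterministic_algorithms, ...)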

auto_round/script/mllm.py

Lines changed: 13 additions & 3 deletions
@@ -173,7 +173,13 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--enable_torch_compile", action="store_true", help="whether to enable torch compile")
 
         self.add_argument(
-            "--disable_deterministic_algorithms", action="store_true", help="disable torch deterministic algorithms."
+            "--disable_deterministic_algorithms",
+            action="store_true",
+            help="deprecated, disable torch deterministic algorithms.",
+        )
+
+        self.add_argument(
+            "--enable_deterministic_algorithms", action="store_true", help="enable torch deterministic algorithms."
         )
 
         ## ======================= VLM =======================
@@ -435,7 +441,11 @@ def tune(args):
     scheme = args.scheme.upper()
     if scheme not in PRESET_SCHEMES:
         raise ValueError(f"{scheme} is not supported. only {PRESET_SCHEMES.keys()} are supported ")
-
+    if args.disable_deterministic_algorithms:
+        logger.warning(
+            "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"
+            " please use enable_deterministic_algorithms instead. "
+        )
     autoround = round(
         model,
         tokenizer,
@@ -473,7 +483,7 @@ def tune(args):
         model_kwargs=model_kwargs,
         data_type=args.data_type,
         disable_opt_rtn=args.disable_opt_rtn,
-        disable_deterministic_algorithms=args.disable_deterministic_algorithms,
+        enable_deterministic_algorithms=args.enable_deterministic_algorithms,
     )
 
     model_name = args.model.rstrip("/")
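
For reference, the two global modes these diffs switch between can be inspected with stock torch queries; this standalone sketch is illustrative and not code from the repo:

    import torch

    # New default: determinism requested in warn-only mode.
    torch.use_deterministic_algorithms(True, warn_only=True)
    print(torch.are_deterministic_algorithms_enabled())           # True
    print(torch.is_deterministic_algorithms_warn_only_enabled())  # True

    # Opt-in via enable_deterministic_algorithms=True (or the deprecated
    # disable_deterministic_algorithms=False): strict mode, errors instead of warnings.
    torch.use_deterministic_algorithms(True, warn_only=False)
    print(torch.is_deterministic_algorithms_warn_only_enabled())  # False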
