Commit e7452d0

Merge pull request #1865 from bghira/chore/1860
(#1860) configure Dynamo for BnB models dynamic outputs
2 parents ff2eddc + a31dce1 commit e7452d0

File tree

simpletuner/helpers/training/trainer.py
simpletuner/helpers/training/validation.py

2 files changed (+49, -2)


simpletuner/helpers/training/trainer.py

Lines changed: 48 additions & 1 deletion
@@ -18,8 +18,8 @@
 from typing import Any, Dict, List, Optional

 import huggingface_hub
-import wandb

+import wandb
 from simpletuner.helpers import log_format  # noqa
 from simpletuner.helpers.caching.memory import reclaim_memory
 from simpletuner.helpers.configuration.cli_utils import mapping_to_cli_args
@@ -49,6 +49,7 @@
     create_optimizer_with_param_groups,
     determine_optimizer_class_with_config,
     determine_params_to_optimize,
+    is_bitsandbytes_available,
     is_lr_schedulefree,
     is_lr_scheduler_disabled,
 )
@@ -238,6 +239,49 @@ def _update_grad_metrics(
         ) and not self.config.use_deepspeed_optimizer:
             target_logs["grad_absmax"] = self.grad_norm

+    def _config_uses_bitsandbytes(self) -> bool:
+        if not getattr(self, "config", None):
+            return False
+
+        def _contains_bnb(value: object) -> bool:
+            if isinstance(value, str):
+                return "bnb" in value.lower()
+            if isinstance(value, dict):
+                return any(_contains_bnb(item) for item in value.values())
+            if isinstance(value, (list, tuple, set)):
+                return any(_contains_bnb(item) for item in value)
+            return False
+
+        for attr_value in vars(self.config).values():
+            try:
+                if _contains_bnb(attr_value):
+                    return True
+            except Exception:
+                continue
+        return False
+
+    def _enable_dynamo_dynamic_output_capture(self) -> None:
+        try:
+            import torch._dynamo as torch_dynamo
+        except Exception as exc:
+            logger.warning("Unable to configure Torch Dynamo dynamic output capture: %s", exc)
+            return
+
+        config_obj = getattr(torch_dynamo, "config", None)
+        if config_obj is None:
+            logger.debug("Torch Dynamo config unavailable; skipping dynamic output capture configuration.")
+            return
+        if not hasattr(config_obj, "capture_dynamic_output_shape_ops"):
+            logger.debug(
+                "Torch Dynamo config lacks capture_dynamic_output_shape_ops; skipping dynamic output capture configuration."
+            )
+            return
+        if getattr(config_obj, "capture_dynamic_output_shape_ops", False):
+            return
+
+        config_obj.capture_dynamic_output_shape_ops = True
+        logger.info("Torch Dynamo capture_dynamic_output_shape_ops enabled for bitsandbytes models.")
+
     def parse_arguments(self, args=None, disable_accelerator: bool = False, exit_on_error: bool = False):
         skip_config_fallback = False
         args_payload = args
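For context, the new _config_uses_bitsandbytes helper is a plain string scan: any config attribute whose value (or nested value) contains "bnb" marks the run as bitsandbytes-backed. Below is a minimal standalone sketch of the same scan against a made-up config namespace; the attribute names and the "bnb-adamw8bit" value are illustrative, not SimpleTuner's real option names.

from types import SimpleNamespace

def contains_bnb(value: object) -> bool:
    # Same recursive test as the helper above: strings are matched on the
    # substring "bnb"; dicts, lists, tuples, and sets are walked element-wise.
    if isinstance(value, str):
        return "bnb" in value.lower()
    if isinstance(value, dict):
        return any(contains_bnb(item) for item in value.values())
    if isinstance(value, (list, tuple, set)):
        return any(contains_bnb(item) for item in value)
    return False

# Hypothetical config object for illustration only.
config = SimpleNamespace(optimizer="bnb-adamw8bit", learning_rate=1e-4)
print(any(contains_bnb(v) for v in vars(config).values()))  # -> True

Because the scan only looks for a substring, any option value mentioning "bnb" triggers it, which keeps the check decoupled from specific optimizer or quantization flag names.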
@@ -402,6 +446,9 @@ def _coerce_flag(value: object) -> bool:

         dynamo_plugin = None
         if resolved_dynamo_backend and resolved_dynamo_backend != DynamoBackend.NO:
+            if is_bitsandbytes_available and self._config_uses_bitsandbytes():
+                self._enable_dynamo_dynamic_output_capture()
+
             plugin_kwargs: Dict[str, object] = {"backend": resolved_dynamo_backend}

             mode_value = getattr(self.config, "dynamo_mode", None)
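The switch the helper flips lives on torch._dynamo.config. As a sketch of what it changes, independent of SimpleTuner: torch.nonzero has a data-dependent output shape, the kind of "dynamic output" op the commit message refers to, and with capture_dynamic_output_shape_ops set Dynamo captures it symbolically instead of graph-breaking. The example uses the eager backend so it runs without a compiler toolchain; it is illustrative, not part of the commit.

import torch
import torch._dynamo as torch_dynamo

# The same flag the new helper sets; it must be flipped before compilation.
torch_dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(backend="eager")
def nonzero_rows(x: torch.Tensor) -> torch.Tensor:
    # The output shape depends on the tensor's values, not just its shape.
    return torch.nonzero(x)

print(nonzero_rows(torch.tensor([0.0, 1.5, 0.0, 2.0])))

Enabling the flag globally in parse_arguments, only when a Dynamo backend is selected and the config mentions bitsandbytes, keeps the default compilation behavior unchanged for everyone else.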

simpletuner/helpers/training/validation.py

Lines changed: 1 addition & 1 deletion
@@ -10,9 +10,9 @@
 import diffusers
 import numpy as np
 import torch
-import wandb
 from tqdm import tqdm

+import wandb
 from simpletuner.helpers.models.common import ImageModelFoundation, ModelFoundation, VideoModelFoundation
 from simpletuner.helpers.training.wrappers import unwrap_model
