Description
System Info
Python version: 3.12.9
transformers version: 5.2.0
bitsandbytes version: 0.49.2
peft version: 0.18.1
torch version: 2.10.0
trl version: 0.29.0
accelerate version: 1.12.0
deepspeed version: 0.18.6
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Using quantized LoRA models with the DPOTrainer and accelerate in a multi-GPU setting (DeepSpeed ZeRO stage 3), I get the following error, but only when gradient_checkpointing=True:
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
I do not get this error when training without LoRA and quantization (i.e., without peft and bitsandbytes).
A related issue is here. As this comment mentions, setting use_reentrant=True works around the error, but appears to affect convergence.
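If the reentrant workaround is acceptable despite the convergence caveat, it can be enabled through the trainer config rather than by patching the model. A minimal sketch, assuming the standard `gradient_checkpointing_kwargs` field that `TrainingArguments` (and hence `DPOConfig`) exposes:

```python
from trl import DPOConfig

# Workaround sketch: use_reentrant=True avoids the metadata-mismatch
# CheckpointError here, but (per the linked comment) may affect convergence.
training_config = DPOConfig(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
    per_device_train_batch_size=1,
)
```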
Reproducible Example (Full Stack Trace below):
from trl import DPOConfig, DPOTrainer
from datasets import load_dataset
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch
training_config = DPOConfig(
    gradient_checkpointing=True,
    max_length=None,
    per_device_train_batch_size=1,
    eval_strategy="no",
    report_to="tensorboard",
)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)
model_name = "Qwen/Qwen3-VL-2B-Instruct"
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(model_name)
peft_model = get_peft_model(model, peft_config)
trainer = DPOTrainer(
    model=peft_model,
    args=training_config,
    train_dataset=load_dataset("HuggingFaceH4/rlaif-v_formatted", split="train[:5%]"),
    processing_class=processor,
)
trainer.train()
Setting gradient_checkpointing=False above works, but is expensive.
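One thing that may be worth trying before disabling checkpointing entirely: PEFT ships a helper that prepares quantized models for k-bit training and wires up gradient checkpointing (including enabling input gradients, which non-reentrant checkpointing needs). A sketch of how it could slot into the repro above; this is a suggestion, not a confirmed fix for this error:

```python
from peft import prepare_model_for_kbit_training, get_peft_model

# Sketch (untested for this exact setup): casts norm/embedding layers,
# enables input gradients, and configures gradient checkpointing on the
# quantized base model before attaching the LoRA adapters.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
peft_model = get_peft_model(model, peft_config)
```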
DeepSpeed configuration
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Full Stack Trace
accelerate launch --config_file ds_mre_peft_bnb.yaml trl_mre_peft_bnb.py
W0304 18:46:12.454000 7431 torch/distributed/run.py:852]
W0304 18:46:12.454000 7431 torch/distributed/run.py:852] *****************************************
W0304 18:46:12.454000 7431 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0304 18:46:12.454000 7431 torch/distributed/run.py:852] *****************************************
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading weights: 98%|▉| 611/625 [00:00<00:00, 1399.29it/s, Materializing param=model.visual.deepstack_merger_list.2.linear_fc1.b
Loading weights: 100%|█████████████████████| 625/625 [00:00<00:00, 763.40it/s, Materializing param=model.visual.pos_embed.weight]
Loading weights: 100%|█████████████████████| 625/625 [00:00<00:00, 726.70it/s, Materializing param=model.visual.pos_embed.weight]
Loading weights: 100%|█████████████████████| 625/625 [00:00<00:00, 731.66it/s, Materializing param=model.visual.pos_embed.weight]
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 151645, 'bos_token_id': None, 'pad_token_id': 151643}.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 151645, 'bos_token_id': None, 'pad_token_id': 151643}.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 151645, 'bos_token_id': None, 'pad_token_id': 151643}.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 151645, 'bos_token_id': None, 'pad_token_id': 151643}.
Installed CUDA version 12.6 does not match the version torch was compiled with 12.8 but since the APIs are compatible, accepting this combination
Installed CUDA version 12.6 does not match the version torch was compiled with 12.8 but since the APIs are compatible, accepting this combination
Installed CUDA version 12.6 does not match the version torch was compiled with 12.8 but since the APIs are compatible, accepting this combination
Installed CUDA version 12.6 does not match the version torch was compiled with 12.8 but since the APIs are compatible, accepting this combination
Stage 3 initialize beginning
MA 1.51 GB Max_MA 3.7 GB CA 2.33 GB Max_CA 4 GB
CPU Virtual Memory: used = 13.89 GB, percent = 7.6%
DeepSpeedZeRoOffload initialize [begin]
MA 1.51 GB Max_MA 1.51 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 13.99 GB, percent = 7.7%
Parameter Offload - Persistent parameters statistics: param_count = 1498, numel = 19075072
DeepSpeedZeRoOffload initialize [end]
MA 0.04 GB Max_MA 1.51 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.76 GB, percent = 8.7%
Before creating fp16 partitions
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.77 GB, percent = 8.7%
After creating fp16 partitions: 1
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.86 GB, percent = 8.7%
Before creating fp32 partitions
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.86 GB, percent = 8.7%
After creating fp32 partitions
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.88 GB, percent = 8.7%
Before initializing optimizer states
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.93 GB, percent = 8.8%
After initializing optimizer states
MA 0.04 GB Max_MA 0.04 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 15.95 GB, percent = 8.8%
After initializing ZeRO optimizer
MA 0.97 GB Max_MA 0.97 GB CA 2.33 GB Max_CA 2 GB
CPU Virtual Memory: used = 16.14 GB, percent = 8.9%
0%| | 0/2964 [00:00<?, ?it/s][rank0]: Traceback (most recent call last):
[rank0]: File "/home/sagemaker-user/trl_mre_peft_bnb.py", line 47, in <module>
[rank0]: trainer.train()
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1412, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1742, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/trl/trainer/dpo_trainer.py", line 1422, in training_step
[rank0]: return super().training_step(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1979, in training_step
[rank0]: self.accelerator.backward(loss, **kwargs)
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/accelerator.py", line 2844, in backward
[rank0]: self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank0]: self.engine.backward(loss, **kwargs)
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2583, in backward
[rank0]: loss.backward(**backward_kwargs)
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1177, in unpack_hook
[rank0]: frame.check_recomputed_tensors_match(gid)
[rank0]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 921, in check_recomputed_tensors_match
[rank0]: raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 4:
[rank0]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 22:
[rank0]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 40:
[rank0]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 82:
[rank0]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: .
[rank0]: Tip: To see a more detailed error message, either pass `debug=True` to
[rank0]: `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
[rank0]: with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
[rank0]: enable checkpoint-debug mode globally.
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/sagemaker-user/trl_mre_peft_bnb.py", line 47, in <module>
[rank1]: trainer.train()
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1412, in train
[rank1]: return inner_training_loop(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1742, in _inner_training_loop
[rank1]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/trl/trainer/dpo_trainer.py", line 1422, in training_step
[rank1]: return super().training_step(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1979, in training_step
[rank1]: self.accelerator.backward(loss, **kwargs)
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/accelerator.py", line 2844, in backward
[rank1]: self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank1]: self.engine.backward(loss, **kwargs)
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2583, in backward
[rank1]: loss.backward(**backward_kwargs)
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank1]: torch.autograd.backward(
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank1]: _engine_run_backward(
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1177, in unpack_hook
[rank1]: frame.check_recomputed_tensors_match(gid)
[rank1]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 921, in check_recomputed_tensors_match
[rank1]: raise CheckpointError(
[rank1]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank1]: tensor at position 4:
[rank1]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 22:
[rank1]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 40:
[rank1]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 83:
[rank1]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: .
[rank1]: Tip: To see a more detailed error message, either pass `debug=True` to
[rank1]: `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
[rank1]: with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
[rank1]: enable checkpoint-debug mode globally.
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/sagemaker-user/trl_mre_peft_bnb.py", line 47, in <module>
[rank3]: trainer.train()
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1412, in train
[rank3]: return inner_training_loop(
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1742, in _inner_training_loop
[rank3]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/trl/trainer/dpo_trainer.py", line 1422, in training_step
[rank3]: return super().training_step(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1979, in training_step
[rank3]: self.accelerator.backward(loss, **kwargs)
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/accelerator.py", line 2844, in backward
[rank3]: self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank3]: self.engine.backward(loss, **kwargs)
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank3]: ret_val = func(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2583, in backward
[rank3]: loss.backward(**backward_kwargs)
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank3]: torch.autograd.backward(
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank3]: _engine_run_backward(
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank3]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1177, in unpack_hook
[rank3]: frame.check_recomputed_tensors_match(gid)
[rank3]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 921, in check_recomputed_tensors_match
[rank3]: raise CheckpointError(
[rank3]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank3]: tensor at position 4:
[rank3]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: tensor at position 22:
[rank3]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: tensor at position 40:
[rank3]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: tensor at position 83:
[rank3]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=3)}
[rank3]: .
[rank3]: Tip: To see a more detailed error message, either pass `debug=True` to
[rank3]: `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
[rank3]: with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
[rank3]: enable checkpoint-debug mode globally.
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/sagemaker-user/trl_mre_peft_bnb.py", line 47, in <module>
[rank2]: trainer.train()
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1412, in train
[rank2]: return inner_training_loop(
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1742, in _inner_training_loop
[rank2]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/trl/trainer/dpo_trainer.py", line 1422, in training_step
[rank2]: return super().training_step(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/transformers/trainer.py", line 1979, in training_step
[rank2]: self.accelerator.backward(loss, **kwargs)
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/accelerator.py", line 2844, in backward
[rank2]: self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank2]: self.engine.backward(loss, **kwargs)
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank2]: ret_val = func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2583, in backward
[rank2]: loss.backward(**backward_kwargs)
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank2]: torch.autograd.backward(
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank2]: _engine_run_backward(
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank2]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1177, in unpack_hook
[rank2]: frame.check_recomputed_tensors_match(gid)
[rank2]: File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 921, in check_recomputed_tensors_match
[rank2]: raise CheckpointError(
[rank2]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank2]: tensor at position 4:
[rank2]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: tensor at position 22:
[rank2]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: tensor at position 40:
[rank2]: saved metadata: {'shape': torch.Size([128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: tensor at position 83:
[rank2]: saved metadata: {'shape': torch.Size([2048]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=2)}
[rank2]: .
[rank2]: Tip: To see a more detailed error message, either pass `debug=True` to
[rank2]: `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
[rank2]: with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
[rank2]: enable checkpoint-debug mode globally.
0%| | 0/2964 [00:07<?, ?it/s]
W0304 18:46:46.408000 7431 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 7561 closing signal SIGTERM
W0304 18:46:46.409000 7431 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 7563 closing signal SIGTERM
W0304 18:46:46.410000 7431 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 7564 closing signal SIGTERM
E0304 18:46:46.775000 7431 torch/distributed/elastic/multiprocessing/api.py:984] failed (exitcode: 1) local_rank: 1 (pid: 7562) of binary: /home/sagemaker-user/.env/bin/python3
Traceback (most recent call last):
File "/home/sagemaker-user/.env/bin/accelerate", line 10, in <module>
sys.exit(main())
^^^^^^
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1266, in launch_command
deepspeed_launcher(args)
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/accelerate/commands/launch.py", line 952, in deepspeed_launcher
distrib_run.run(args)
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/distributed/run.py", line 982, in run
elastic_launch(
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sagemaker-user/.env/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 317, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
trl_mre_peft_bnb.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2026-03-04_18:46:46
host : default
rank : 0 (local_rank: 0)
exitcode : -15 (pid: 7561)
error_file: <N/A>
traceback : Signal 15 (SIGTERM) received by PID 7561
[2]:
time : 2026-03-04_18:46:46
host : default
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 7563)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
time : 2026-03-04_18:46:46
host : default
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 7564)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2026-03-04_18:46:46
host : default
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 7562)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Expected behavior
The expected behavior is for the DPOTrainer's train() method to run without errors when gradient checkpointing is enabled (unless the errors are of my own making).