
ValueError when using peft on FSDPTrainer #90

Description

@AragornHorse

ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I tried running this code on two 80GB A100s, adding a PEFT LoRA adapter in train.py:

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=config.lora.r,
    lora_alpha=config.lora.alpha,
    lora_dropout=config.lora.dropout,
)

policy = get_peft_model(policy, peft_config)
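For context, the dtype mix in the error seems to be the bfloat16 base weights vs. the float32 LoRA matrices that get_peft_model injects. A quick, purely illustrative way to see it (not code from the repo):

from collections import Counter

# Count parameter dtypes after wrapping with get_peft_model: the frozen base
# weights stay in bfloat16 while the injected LoRA parameters are float32,
# which is exactly the mix FSDP complains about when flattening.
print(Counter(p.dtype for p in policy.parameters()))
# e.g. Counter({torch.bfloat16: ..., torch.float32: ...})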

Training works fine with BasicTrainer, but with FSDPTrainer I get:

Traceback (most recent call last):
  File "/direct-preference-optimization/train.py", line 127, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
    while not context.join():
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 189, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
    fn(i, *args)
  File "direct-preference-optimization/train.py", line 43, in worker_main
    trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
  File "direct-preference-optimization/trainers.py", line 469, in __init__
    self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
    _auto_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
    return wrapper_cls(module, **kwargs)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
    _init_param_handle_from_module(
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
    self._init_flat_param_and_metadata(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I tried casting the LoRA modules to bfloat16, but then other ValueErrors occur.
I also tried use_orig_params=True, which didn't help either.
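Concretely, the two attempts looked roughly like this (a sketch; shared_fsdp_kwargs and policy_mp_policy are the same objects used in trainers.py line 469):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Attempt 1: cast the LoRA adapter parameters to bfloat16 so that everything
# FSDP flattens shares one dtype. This removes the dtype error, but other
# ValueErrors show up later.
for name, param in policy.named_parameters():
    if "lora_" in name:
        param.data = param.data.to(torch.bfloat16)

# Attempt 2: ask FSDP to keep the original (unflattened) parameters.
# This did not help either.
policy = FSDP(
    policy,
    **shared_fsdp_kwargs,
    mixed_precision=policy_mp_policy,
    use_orig_params=True,
)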

How can I solve this?
