
Fine-tuning OOM during eval at lm_loss, metrics = forward_step(data_iterator, model, args, timers) #3

@sunshinejuanjuan123

Description


I ran the following command to fine-tune on my own data:

bash finetune_single_identity.sh # Multi GPUs

On 8 H100 GPUs, evaluation runs out of GPU memory (OOM). The relevant call in train_video.py is:


from functools import partial

from sat.training.deepspeed_training import training_main

training_main(
    args,
    model_cls=SATVideoDiffusionEngine,
    forward_step_function=partial(forward_step, data_class=data_class),
    forward_step_eval=partial(
        forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents
    ),
    create_dataset_function=create_dataset_function,
)
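The traceback shows that the `forward_step(data_iterator, model, args, timers)` call in sat's `evaluate` lands in train_video.py's `forward_step_eval`, i.e. the callable passed as `forward_step_eval` above. `functools.partial` pre-binds `data_class` and `only_log_video_latents`, so the loop only supplies the four positional arguments. A minimal sketch of that binding, using hypothetical stand-ins for the eval step and the loop (not sat's actual code):

```python
from functools import partial


# Hypothetical stand-in for train_video.py's forward_step_eval: the real one
# runs the model and logs videos; this one only records what it received.
def forward_step_eval(data_iterator, model, args, timers,
                      data_class=None, only_log_video_latents=False):
    return {
        "batch": next(data_iterator),
        "data_class": data_class,
        "only_log_video_latents": only_log_video_latents,
    }


# Hypothetical stand-in for the evaluation loop: like the evaluate frame in
# the traceback, it only passes the four positional arguments.
def evaluate(data_iterator, model, eval_iters, args, timers, forward_step):
    results = []
    for _ in range(eval_iters):
        results.append(forward_step(data_iterator, model, args, timers))
    return results


# Keyword arguments are frozen when the partial is built, as in training_main(...).
eval_step = partial(forward_step_eval, data_class="MyDataClass", only_log_video_latents=True)
out = evaluate(iter([0, 1]), model=None, eval_iters=2, args=None, timers=None,
               forward_step=eval_step)
print(out[1])
```

Because the loop passes nothing else per call, any eval-only memory setting has to be bound into the partial or read from `args` inside the eval step.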

When running training_main from sat.training.deepspeed_training, GPU memory runs out (OOM) during eval. How can I fix this?

/mnt/afs//Concat-ID/train_video.py:74: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.                                    
  with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):                                                                                                                                          
[rank0]: Traceback (most recent call last):                                                                                                                                                                      
[rank0]:   File "/mnt/afs//Concat-ID/train_video.py", line 235, in <module>                                                                                                                          
[rank0]:     training_main(                                                                                                                                                                                      
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/sat/training/deepspeed_training.py", line 158, in training_main                                                                   
[rank0]:     iteration, skipped = train(model, optimizer,                                                                                                                                                        
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                        
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/sat/training/deepspeed_training.py", line 414, in train                                                                           
[rank0]:     evaluate_and_print_results(                                                                                                                                                                         
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/sat/training/deepspeed_training.py", line 629, in evaluate_and_print_results                                                      
[rank0]:     lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, split, verbose, has_last, hooks=hooks)                                                                                  
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                  
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/sat/training/deepspeed_training.py", line 571, in evaluate                                                                        
[rank0]:     lm_loss, metrics = forward_step(data_iterator, model, args, timers)                                                                                                                                 
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                 
[rank0]:   File "/mnt/afs//Concat-ID/train_video.py", line 173, in forward_step_eval                                                                                                                 
[rank0]:     log_video(batch_video, model, args, only_log_video_latents=only_log_video_latents)                                                                                                                  
[rank0]:   File "/mnt/afs//Concat-ID/train_video.py", line 75, in log_video                                                                                                                          
[rank0]:     videos = model.log_video(batch, only_log_video_latents=only_log_video_latents)                                                                                                                      
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                      
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context                                                                        
[rank0]:     return func(*args, **kwargs)                                                                                                                                                                        
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                        
[rank0]:   File "/mnt/afs//Concat-ID/diffusion_video.py", line 334, in log_video                                                                                                                     
[rank0]:     log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)                                                                                                                               
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                 
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context                                                                        
[rank0]:     return func(*args, **kwargs)                                                                                                                                                                        
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                        
[rank0]:   File "/mnt/afs//Concat-ID/diffusion_video.py", line 203, in decode_first_stage                                                                                                            
[rank0]:     out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)                                                                                                               
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                               
[rank0]:   File "/mnt/afs//Concat-ID/vae_modules/autoencoder.py", line 634, in decode                                                                                                                
[rank0]:     x = super().decode(z, **kwargs)                                                                                                                                                                     
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                     
[rank0]:   File "/mnt/afs//Concat-ID/vae_modules/autoencoder.py", line 233, in decode                                                                                                                
[rank0]:     x = self.decoder(z, **kwargs)                                                                                                                                                                       
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                       
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                     
[rank0]:     return self._call_impl(*args, **kwargs)                                                                                                                                                             
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                             
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                             
[rank0]:     return forward_call(*args, **kwargs)                                                                                                                                                                
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                
[rank0]:   File "/mnt/afs//Concat-ID/vae_modules/cp_enc_dec.py", line 970, in forward                                                                                                                
[rank0]:     h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)                                                                                                           
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                           
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                     
[rank0]:     return self._call_impl(*args, **kwargs)                                                                                                                                                             
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                             
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                             
[rank0]:     return forward_call(*args, **kwargs)                                                                                                                                                                
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                
[rank0]:   File "/mnt/afs//Concat-ID/vae_modules/cp_enc_dec.py", line 682, in forward                                                                                                                
[rank0]:     h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)                                                                                                                                      
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                      
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                     
[rank0]:     return self._call_impl(*args, **kwargs)                                                                                                                                                             
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                             
[rank0]:   File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                             
[rank0]:     return forward_call(*args, **kwargs)                                                                                                                                                                
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                
[rank0]:   File "/mnt/afs//Concat-ID/vae_modules/cp_enc_dec.py", line 509, in forward                                                                                                                
[rank0]:     new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)                                                                                                                                                  
[rank0]:             ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~                                                                                                                                                  
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.08 GiB. GPU 0 has a total capacity of 79.33 GiB of which 1.54 GiB is free. Process 293276 has 77.75 GiB memory in use. Of the allocated memory 74.27 GiB is allocated by PyTorch, and 108.13 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
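Setting the figures from this OOM message side by side shows the size of the gap: the decoder asked for 8.08 GiB while 1.54 GiB was free, and only 108.13 MiB was reserved but unallocated, which is the most `torch.cuda.empty_cache()` can return to the device. The calculation below uses only numbers copied from the message; labelling the remainder as CUDA context and communication buffers is an interpretation, not something the log states.

```python
# All figures are copied from the torch.OutOfMemoryError message above.
GIB_PER_MIB = 1 / 1024

requested_gib = 8.08
free_gib = 1.54
process_in_use_gib = 77.75
allocated_by_pytorch_gib = 74.27
reserved_unallocated_gib = 108.13 * GIB_PER_MIB  # about 0.106 GiB

# torch.cuda.empty_cache() only releases cached-but-unallocated blocks,
# so in the best case the free pool grows by reserved_unallocated_gib.
best_case_free_gib = free_gib + reserved_unallocated_gib
shortfall_gib = requested_gib - best_case_free_gib

# Memory held by the process that PyTorch's allocator does not account for
# (interpreted here as CUDA context, communication buffers and similar).
non_pytorch_gib = process_in_use_gib - allocated_by_pytorch_gib - reserved_unallocated_gib

print(f"best-case free after empty_cache: {best_case_free_gib:.2f} GiB")
print(f"still short by:                   {shortfall_gib:.2f} GiB")
print(f"held outside PyTorch's allocator: {non_pytorch_gib:.2f} GiB")
```

The message's own hint about `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` targets the case where reserved-but-unallocated memory is large. At roughly 0.1 GiB it is small here, so fragmentation does not look like the main cause; most of the card is held by live PyTorch allocations (74.27 GiB), which cache-clearing cannot release.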

[rank0]: File "/mnt/afs/miniconda/envs/concatid/lib/python3.11/site-packages/sat/training/deepspeed_training.py", line 571, in evaluate
[rank0]: lm_loss, metrics = forward_step(data_iterator, model, args, timers)
I tried adding torch.cuda.empty_cache() and gc.collect() at this point, but the OOM still occurs.
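`gc.collect()` and `torch.cuda.empty_cache()` only help with memory that is no longer referenced, so they cannot shrink tensors that are live inside the VAE decoder. The frame at `diffusion_video.py` line 203 shows that `decode_first_stage` already decodes the latent batch in slices, `z[n * n_samples : (n + 1) * n_samples]`, so peak decoder memory follows the slice size rather than the whole batch. The sketch below reproduces only that slicing pattern in plain Python, with a hypothetical `decode_fn` and lists standing in for tensors; how `n_samples` is configured in Concat-ID is not shown in the issue and is left as an assumption.

```python
import math
from typing import Callable, List, Sequence


def decode_in_slices(z: Sequence, decode_fn: Callable[[Sequence], List], n_samples: int) -> List:
    """Decode latents n_samples at a time, mirroring the slice
    z[n * n_samples : (n + 1) * n_samples] seen in decode_first_stage.

    Only one slice reaches decode_fn per round, so the decoder's working set
    is bounded by the slice size, not by len(z).
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    n_rounds = math.ceil(len(z) / n_samples)
    outputs: List = []
    for n in range(n_rounds):
        chunk = z[n * n_samples : (n + 1) * n_samples]
        # Real code would collect tensors and concatenate them (e.g. torch.cat).
        outputs.extend(decode_fn(chunk))
    return outputs


# Hypothetical decoder that records how many samples it received per call.
calls = []


def fake_decode(chunk):
    calls.append(len(chunk))
    return [x * 10 for x in chunk]


result = decode_in_slices([1, 2, 3, 4, 5], fake_decode, n_samples=2)
print(result, calls)
```

If each GPU already decodes a single sample per round, slicing by sample cannot lower the peak any further, and the 8.08 GiB allocation at `cp_enc_dec.py` line 509 would then be coming from one sample's decoder activations.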
