
RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2 #83

@erikqu


It seems like all my batches have some underlying issue where they're all off by one. I've seen other issues opened about this, but no proper explanation. Could I get some help with this?

Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2

I verified that the text and wav batches both have the batch size (16); in this case all wavs are padded to 84480 samples.
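For reference, here is a minimal sketch of how an off-by-one like this can arise, assuming a U-Net level that halves the length with a strided `Conv1d`, doubles it back with a `ConvTranspose1d`, and then adds the skip connection element-wise. This is a simplified stand-in, not the actual `a_unet` code, and `stage` is a hypothetical name: when the length at some level is odd, the floor on the way down cannot be undone on the way up.

```python
import torch
import torch.nn as nn

# Toy U-Net level (hypothetical, NOT a_unet's implementation):
# a stride-2 conv halves the length (flooring odd lengths),
# a stride-2 transposed conv doubles it back.
down = nn.Conv1d(1, 1, kernel_size=2, stride=2)
up = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=2)


def stage(x: torch.Tensor) -> torch.Tensor:
    skip = x
    h = up(down(x))  # length L -> L // 2 -> 2 * (L // 2)
    return skip + h  # element-wise skip add; fails if L is odd


print(stage(torch.randn(16, 1, 90)).shape)  # torch.Size([16, 1, 90])

try:
    stage(torch.randn(16, 1, 91))
except RuntimeError as e:
    # The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2
    print(e)
```

With an even length the shapes line up; with 91 the downsample produces 45, the upsample produces 90, and the skip add fails with the same kind of broadcast error as above.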

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2
Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2
The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2
torch.Size([16, 84480]) 16
Traceback (most recent call last):
  File "/mnt/nvme/programs/qTTS/train_ttv_v1.py", line 184, in train_and_evaluate
    loss_gen_all = net_g(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/models.py", line 40, in forward
    return self.diffusion(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward
    v_pred = self.net(x_noisy, sigmas, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 594, in forward
    return net(x, features=features, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 621, in forward
    return net(x, embedding=text_embedding, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 552, in forward
    return net(x, embedding=embedding, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 431, in forward
    return self.net(x, features, embedding, channels)  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl

I followed the example from the README:

  from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

  net_g = DiffusionModel(
      net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
      in_channels=1, # U-Net: number of input/output (audio) channels
      channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
      factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
      items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
      attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
      attention_heads=8, # U-Net: number of attention heads per attention item
      attention_features=64, # U-Net: number of attention features per attention item
      diffusion_t=VDiffusion, # The diffusion method used
      sampler_t=VSampler, # The diffusion sampler used
      use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
      use_embedding_cfg=True, # U-Net: enables classifier-free guidance
      embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
      embedding_features=768, # U-Net: text embedding features (default for T5-base)
      cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
  )
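A quick check of this config, assuming the U-Net needs the input length to be divisible by the product of `factors` (a common constraint for strided U-Nets, not something confirmed here for `a_unet`): the product is 4 × 4 × 4 × 2 × 2 × 2 × 2 × 2 = 2048, and 84480 = 41.25 × 2048, so 84480 is not a multiple. The helpers `level_lengths` and `pad_to_multiple` below are hypothetical: the first traces the length at each level under a simple floor-division model, and the second right-pads the time dimension with zeros up to the next multiple. The `[batch, channels, length]` input shape is also an assumption.

```python
import math

import torch
import torch.nn.functional as F

# Copied from the DiffusionModel config above.
FACTORS = [1, 4, 4, 4, 2, 2, 2, 2, 2]


def level_lengths(length: int, factors: list[int]) -> list[int]:
    """Length after each downsampling factor, assuming every level floors (simplified model)."""
    lengths = []
    for f in factors:
        length //= f
        lengths.append(length)
    return lengths


def pad_to_multiple(wav: torch.Tensor, multiple: int) -> torch.Tensor:
    """Right-pad the last (time) dimension with zeros up to the next multiple of `multiple`."""
    length = wav.shape[-1]
    target = (length + multiple - 1) // multiple * multiple
    return F.pad(wav, (0, target - length))


multiple = math.prod(FACTORS)  # 2048
print(84480 % multiple)  # 512, so 84480 is not a multiple of 2048
print(level_lengths(84480, FACTORS))  # [84480, 21120, 5280, 1320, 660, 330, 165, 82, 41]

wavs = torch.zeros(16, 1, 84480)  # hypothetical batch: [batch, channels, length]
print(pad_to_multiple(wavs, multiple).shape)  # torch.Size([16, 1, 86016])
```

Under this simplified model the first odd level is 165 (which would give 165 vs 164), not 91 vs 90, so the exact sizes in the error depend on `a_unet`'s real layer arithmetic or on a different padded length in the failing batch. Cropping to the previous multiple (83968) instead of padding would be the other option.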
