[BUG] Device error when running on a CUDA device other than cuda:0 #215

@cornzz

Description

Python -VV

Python 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Pip Freeze

conda env export
name: test
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - asttokens=2.0.5=pyhd3eb1b0_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - comm=0.2.1=py311h06a4308_0
  - debugpy=1.6.7=py311h6a678d5_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - executing=0.8.3=pyhd3eb1b0_0
  - ipykernel=6.28.0=py311h06a4308_0
  - ipython=8.25.0=py311h06a4308_0
  - jedi=0.19.1=py311h06a4308_0
  - jupyter_client=8.6.0=py311h06a4308_0
  - jupyter_core=5.7.2=py311h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.6=py311h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=py311h06a4308_0
  - openssl=3.0.14=h5eee18b_0
  - packaging=24.1=py311h06a4308_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pip=24.2=py311h06a4308_0
  - platformdirs=3.10.0=py311h06a4308_0
  - prompt-toolkit=3.0.43=py311h06a4308_0
  - prompt_toolkit=3.0.43=hd3eb1b0_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pygments=2.15.1=py311h06a4308_1
  - python=3.11.9=h955ad1f_0
  - python-dateutil=2.9.0post0=py311h06a4308_2
  - pyzmq=25.1.2=py311h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=72.1.0=py311h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py311h5eee18b_0
  - traitlets=5.14.3=py311h06a4308_0
  - typing_extensions=4.11.0=py311h06a4308_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - wheel=0.43.0=py311h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=h6a678d5_0
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==0.33.0
      - aiohappyeyeballs==2.4.0
      - aiohttp==3.10.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.4.0
      - attrs==24.2.0
      - certifi==2024.7.4
      - charset-normalizer==3.3.2
      - click==8.1.7
      - datasets==2.21.0
      - dill==0.3.8
      - distro==1.9.0
      - docstring-parser==0.16
      - evaluate==0.4.2
      - filelock==3.15.4
      - fire==0.6.0
      - frozenlist==1.4.1
      - fsspec==2024.6.1
      - fuzzywuzzy==0.18.0
      - h11==0.14.0
      - httpcore==1.0.5
      - httpx==0.27.1
      - huggingface-hub==0.24.5
      - idna==3.7
      - jieba==0.42.1
      - jinja2==3.1.4
      - jiter==0.5.0
      - joblib==1.4.2
      - jsonschema==4.23.0
      - jsonschema-specifications==2023.12.1
      - llmlingua==0.2.2
      - markupsafe==2.1.5
      - mistral-common==1.3.4
      - mistral-inference==1.3.1
      - mpmath==1.3.0
      - multidict==6.0.5
      - multiprocess==0.70.16
      - networkx==3.3
      - nltk==3.8.1
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.6.20
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.42.0
      - pandas==2.2.2
      - psutil==6.0.0
      - pyarrow==17.0.0
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pytz==2024.1
      - pyyaml==6.0.2
      - referencing==0.35.1
      - regex==2024.7.24
      - requests==2.32.3
      - rouge==1.0.1
      - rpds-py==0.20.0
      - safetensors==0.4.4
      - sentencepiece==0.2.0
      - simple-parsing==0.1.5
      - sniffio==1.3.1
      - sympy==1.13.2
      - termcolor==2.4.0
      - tiktoken==0.7.0
      - tokenizers==0.19.1
      - torch==2.4.0
      - tqdm==4.66.5
      - transformers==4.44.0
      - triton==3.0.0
      - typing-extensions==4.12.2
      - tzdata==2024.1
      - urllib3==2.2.2
      - xformers==0.0.27.post2
      - xxhash==3.5.0
      - yarl==1.9.4
prefix: /home/test

Reproduction Steps

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Load the model onto a GPU other than cuda:0
model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
tokenizer = MistralTokenizer.from_file("./models/mistral-7B-v0.1/tokenizer.model").instruct_tokenizer.tokenizer

prompt = "What is the capital of Germany? Answer:"
tokens = tokenizer.encode(prompt, bos=True, eos=False)
out_tokens, logprobs = generate([tokens], model, max_tokens=50, temperature=0)
result = tokenizer.decode(out_tokens[0])

Expected Behavior

Generation should work the same on any CUDA device, but I get the following error when running the code above:

ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0

This seems related to facebookresearch/xformers#1064, but I have not figured out why it happens yet.
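A possible workaround, assuming my reading of the error is right (the attention mask seems to be allocated on the current default CUDA device, cuda:0, rather than on model.device): make the target GPU the default device before loading. Untested sketch:

import torch
from mistral_inference.transformer import Transformer

# Workaround sketch (assumption: the cache's attention mask is allocated on
# the default CUDA device rather than on the device passed to from_folder).
# Making cuda:7 the default keeps the mask and the model on the same GPU.
torch.cuda.set_device("cuda:7")
model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")

Alternatively, launching the script with CUDA_VISIBLE_DEVICES=7 makes the target GPU appear as cuda:0 inside the process, so the default device="cuda" path should work unchanged.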

Additional Context

Stack trace
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/generate.py", line 82, in generate                       
    prelogits = model.forward(                                                                                                                           
                ^^^^^^^^^^^^^^                                                                                                                           
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 276, in forward                    
    h = self.forward_partial(input_ids, seqlens, cache=cache)                                                                                            
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                            
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 258, in forward_partial            
    h = layer(h, freqs_cis, cache_view)                                                                                                                  
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl              
    return self._call_impl(*args, **kwargs)                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                              
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl                      
    return forward_call(*args, **kwargs)                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                 
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 156, in forward                    
    r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)                                                                                 
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                 
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 100, in forward                    
    output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)                                                             
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                             
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 276, in memory_efficient_attention    
    return _memory_efficient_attention(                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 395, in _memory_efficient_attention   
    return _memory_efficient_attention_forward(                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                          
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 411, in _memory_efficient_attention_for
ward                                                                                                                                                     
    inp.validate_inputs()                                                                                                                                
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/common.py", line 145, in validate_inputs                 
    raise ValueError(
ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0
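Per the trace, the check that raises is xformers' validate_inputs (xformers/ops/fmha/common.py). For reference, the same ValueError can be triggered in isolation by passing an attention bias on a different device than Q/K/V. A minimal sketch, not the actual mistral_inference code path (assumes at least two visible GPUs; shapes are illustrative):

import torch
from xformers.ops import memory_efficient_attention

B, M, H, K = 1, 8, 4, 64  # batch, sequence length, heads, head dim
q = torch.randn(B, M, H, K, device="cuda:1", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda:1", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda:1", dtype=torch.float16)
# Bias deliberately placed on a different GPU than Q/K/V.
bias = torch.zeros(B, H, M, M, device="cuda:0", dtype=torch.float16)

# Raises the same "Attention bias and Query/Key/Value should be on the
# same device" ValueError as in the stack trace above.
out = memory_efficient_attention(q, k, v, attn_bias=bias)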

Suggested Solutions

No response
