Description
Python -VV
Python 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Pip Freeze

Output of conda env export (in place of pip freeze):
name: test
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=2.0.5=pyhd3eb1b0_0
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- comm=0.2.1=py311h06a4308_0
- debugpy=1.6.7=py311h6a678d5_0
- decorator=5.1.1=pyhd3eb1b0_0
- executing=0.8.3=pyhd3eb1b0_0
- ipykernel=6.28.0=py311h06a4308_0
- ipython=8.25.0=py311h06a4308_0
- jedi=0.19.1=py311h06a4308_0
- jupyter_client=8.6.0=py311h06a4308_0
- jupyter_core=5.7.2=py311h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- matplotlib-inline=0.1.6=py311h06a4308_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=py311h06a4308_0
- openssl=3.0.14=h5eee18b_0
- packaging=24.1=py311h06a4308_0
- parso=0.8.3=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pip=24.2=py311h06a4308_0
- platformdirs=3.10.0=py311h06a4308_0
- prompt-toolkit=3.0.43=py311h06a4308_0
- prompt_toolkit=3.0.43=hd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pure_eval=0.2.2=pyhd3eb1b0_0
- pygments=2.15.1=py311h06a4308_1
- python=3.11.9=h955ad1f_0
- python-dateutil=2.9.0post0=py311h06a4308_2
- pyzmq=25.1.2=py311h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=72.1.0=py311h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0
- stack_data=0.2.0=pyhd3eb1b0_0
- tk=8.6.14=h39e8969_0
- tornado=6.4.1=py311h5eee18b_0
- traitlets=5.14.3=py311h06a4308_0
- typing_extensions=4.11.0=py311h06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.43.0=py311h06a4308_0
- xz=5.4.6=h5eee18b_1
- zeromq=4.3.5=h6a678d5_0
- zlib=1.2.13=h5eee18b_1
- pip:
- accelerate==0.33.0
- aiohappyeyeballs==2.4.0
- aiohttp==3.10.5
- aiosignal==1.3.1
- annotated-types==0.7.0
- anyio==4.4.0
- attrs==24.2.0
- certifi==2024.7.4
- charset-normalizer==3.3.2
- click==8.1.7
- datasets==2.21.0
- dill==0.3.8
- distro==1.9.0
- docstring-parser==0.16
- evaluate==0.4.2
- filelock==3.15.4
- fire==0.6.0
- frozenlist==1.4.1
- fsspec==2024.6.1
- fuzzywuzzy==0.18.0
- h11==0.14.0
- httpcore==1.0.5
- httpx==0.27.1
- huggingface-hub==0.24.5
- idna==3.7
- jieba==0.42.1
- jinja2==3.1.4
- jiter==0.5.0
- joblib==1.4.2
- jsonschema==4.23.0
- jsonschema-specifications==2023.12.1
- llmlingua==0.2.2
- markupsafe==2.1.5
- mistral-common==1.3.4
- mistral-inference==1.3.1
- mpmath==1.3.0
- multidict==6.0.5
- multiprocess==0.70.16
- networkx==3.3
- nltk==3.8.1
- numpy==1.26.4
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==9.1.0.70
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.6.20
- nvidia-nvtx-cu12==12.1.105
- openai==1.42.0
- pandas==2.2.2
- psutil==6.0.0
- pyarrow==17.0.0
- pydantic==2.8.2
- pydantic-core==2.20.1
- pytz==2024.1
- pyyaml==6.0.2
- referencing==0.35.1
- regex==2024.7.24
- requests==2.32.3
- rouge==1.0.1
- rpds-py==0.20.0
- safetensors==0.4.4
- sentencepiece==0.2.0
- simple-parsing==0.1.5
- sniffio==1.3.1
- sympy==1.13.2
- termcolor==2.4.0
- tiktoken==0.7.0
- tokenizers==0.19.1
- torch==2.4.0
- tqdm==4.66.5
- transformers==4.44.0
- triton==3.0.0
- typing-extensions==4.12.2
- tzdata==2024.1
- urllib3==2.2.2
- xformers==0.0.27.post2
- xxhash==3.5.0
- yarl==1.9.4
prefix: /home/test

Reproduction Steps
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
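# Load the weights directly onto GPU 7; the process default CUDA device stays cuda:0.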
model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
tokenizer = MistralTokenizer.from_file("./models/mistral-7B-v0.1/tokenizer.model").instruct_tokenizer.tokenizer
prompt = "What is the capital of germany? Answer:"
tokens = tokenizer.encode(prompt, bos=True, eos=False)
out_tokens, logprobs = generate([tokens], model, max_tokens=50, temperature=0)
result = tokenizer.decode(out_tokens[0])

Expected Behavior
I get the following error when trying to run the above code:
ValueError: Attention bias and Query/Key/Value should be on the same device
query.device: cuda:7
attn_bias : cuda:0
This seems related to facebookresearch/xformers#1064, but I haven't been able to figure out why it happens yet.
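The error suggests the attention mask (attn_bias) is allocated on the process's default CUDA device, cuda:0, while the query/key/value tensors live on cuda:7 where the model was loaded. A possible workaround, sketched below (untested, assuming the same model path as the reproduction above), is to hide the other GPUs before CUDA is initialized, so the target GPU becomes cuda:0 and the two devices coincide:

# Possible workaround (sketch): expose only physical GPU 7 to this process.
# This must run before `import torch` or anything else that initializes CUDA.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

from mistral_inference.transformer import Transformer

# Physical GPU 7 is now visible to torch as cuda:0, matching the default device.
model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:0")

Alternatively, calling torch.cuda.set_device(7) before loading the model might work, though I have not verified whether the mask allocation in mistral_inference respects the current device.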
Additional Context
Stack trace
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/generate.py", line 82, in generate
prelogits = model.forward(
^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 276, in forward
h = self.forward_partial(input_ids, seqlens, cache=cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 258, in forward_partial
h = layer(h, freqs_cis, cache_view)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 156, in forward
r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 100, in forward
output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 276, in memory_efficient_attention
return _memory_efficient_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 395, in _memory_efficient_attention
return _memory_efficient_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 411, in _memory_efficient_attention_for
ward
inp.validate_inputs()
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/common.py", line 145, in validate_inputs
raise ValueError(
ValueError: Attention bias and Query/Key/Value should be on the same device
query.device: cuda:7
attn_bias : cuda:0
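For what it's worth, the device mismatch can be confirmed by comparing the process's default CUDA device with the device the weights actually live on (a quick check using the model object from the reproduction above):

import torch

# Default device for new CUDA allocations in this process (0 unless changed).
print(torch.cuda.current_device())        # -> 0
# Device the loaded weights reside on.
print(next(model.parameters()).device)    # -> cuda:7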
Suggested Solutions
No response