fix: propagate precision correctly to enable non-bf16 inference #165

Icedgarr wants to merge 3 commits into metavoiceio:main

Conversation
fam/llm/fast_model.py (Outdated)
```diff
  # key, query, value projections for all heads, but in a batch
- self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False, dtype=config.dtype)
  self.wo = nn.Linear(config.dim, config.dim, bias=False)
```
why're these required? I think with the fix here, this shouldn't be needed?
The following line was throwing an error due to mixed dtypes: q was float32 while k and v were float16 (see the error below).

```python
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
```

Error:

```
TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(2, 16, s0, 128)), FakeTensor(..., device='cuda:0', size=(2, 16, 2048, 128), dtype=torch.float16), FakeTensor(..., device='cuda:0', size=(2, 16, 2048, 128), dtype=torch.float16)), **{'attn_mask': FakeTensor(..., device='cuda:0', size=(1, 1, s0, 2048), dtype=torch.bool), 'dropout_p': 0.0}): Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.
```
However, I have just checked and this change alone did not solve the issue; it only worked after I also ran the code with TorchDynamo disabled (`export TORCHDYNAMO_DISABLE=1`). It may be one of the operations applied to the key and value tensors (I suspect `k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)`, since it is the only one performed on k and v but not on q).

If you prefer, I can remove this change from the PR; if I or someone else finds the root cause and a way to solve it, we can open another PR.
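For reference, here is a minimal, self-contained sketch of the two pieces involved: the dtype validation inside `F.scaled_dot_product_attention`, and a projection built with an explicit dtype as in the diff above. The dimensions and variable names are made up for illustration; this is not the repository's actual `Attention` code.

```python
# Minimal sketch (illustrative names/shapes, not fam/llm/fast_model.py itself).
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, n_head, head_dim = 256, 8, 32

# The change in the diff: build the projection in the configured precision so
# its weights (and therefore q, k, v) match the rest of the model.
wqkv = nn.Linear(dim, 3 * dim, bias=False, dtype=torch.float16)
print(wqkv.weight.dtype)  # torch.float16

# What the error above boils down to: SDPA validates dtypes before computing
# anything and refuses mixed query/key/value dtypes.
q = torch.randn(2, n_head, 10, head_dim)                       # float32
k = torch.randn(2, n_head, 10, head_dim, dtype=torch.float16)  # float16
v = torch.randn(2, n_head, 10, head_dim, dtype=torch.float16)  # float16
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    print(err)  # "Expected query, key, and value to have the same dtype ..."
```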
So queries are the right dtype but keys and values are not? That sounds like it might be related to kv-cache not being the right dtype ... but we seem to be setting it correctly here... did you already have a look there?
That's right. I see that the kv-cache runs when I execute the code, so it is likely what changes the dtypes, which, according to the code you reference, should not happen. If you agree, I'll remove this part of the PR and investigate a bit further tomorrow to try to fix this other issue.
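To make the kv-cache hypothesis concrete, here is an illustrative toy cache (not the repository's actual `KVCache` class; names, shapes and the update logic are made up). The point it demonstrates: because the cache buffers are allocated in a fixed dtype at construction time, whatever k/v are written into them comes back out in the buffer's dtype, regardless of the query's dtype.

```python
# Toy cache to illustrate the dtype discussion above; not the real KVCache.
import torch
import torch.nn as nn

class ToyKVCache(nn.Module):
    def __init__(self, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        shape = (1, n_heads, max_seq_len, head_dim)
        # The buffer dtype is fixed here, at construction time.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, end_pos, k, v):
        # Slice assignment goes through copy_, which silently casts k/v to the
        # buffer's dtype, so the returned tensors are always in that dtype.
        self.k_cache[:, :, :end_pos] = k
        self.v_cache[:, :, :end_pos] = v
        return self.k_cache, self.v_cache

cache = ToyKVCache(max_seq_len=16, n_heads=2, head_dim=4, dtype=torch.float16)
k_in = torch.randn(1, 2, 3, 4)          # float32, e.g. from a layer built without dtype=
k_out, _ = cache.update(3, k_in, k_in)
print(k_in.dtype, "->", k_out.dtype)    # torch.float32 -> torch.float16
```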
This reverts commit f01a9ce.

I have reverted the last commit since it was not required for this fix.
This PR fixes some incompatibilities that I encountered when instantiating `TTS` from `fam/llm/fast_inference.py` on older and less powerful GPUs (e.g. the Google Colab T4 GPU). `fam/llm/fast_inference_utils.py` was moving the model to the `device` (cuda) with dtype `bfloat16` instead of using the `precision` parameter that contains the selected dtype (by default `float16` or `bfloat16`, depending on the GPU architecture). The linear layer of the `Attention` class in `fam/llm/fast_model.py` was also missing the dtype definition using the one provided in the config.
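As a rough sketch of the precision fix described above (illustrative function names; not the exact `fam/llm/fast_inference_utils.py` code), the idea is simply to pass the selected precision through to `model.to(...)` instead of hard-coding `torch.bfloat16`:

```python
# Illustrative sketch of propagating the selected precision; names are made up.
import torch

def pick_precision() -> torch.dtype:
    # bfloat16 only pays off on GPUs that support it natively (Ampere and newer);
    # older GPUs such as a Colab T4 fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

def load_model(model: torch.nn.Module, device: str, precision: torch.dtype) -> torch.nn.Module:
    # Before the fix (conceptually): model.to(device=device, dtype=torch.bfloat16)
    # After the fix: honour the caller's precision choice end to end.
    return model.to(device=device, dtype=precision)
```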