
Commit 9d8b44f

IanWood1 authored and IanNod committed
[sharktank] Fix attention dtype (nod-ai#1243)
Fixes a bug in the refactor (nod-ai#1098) that removed the ability to specify different dtypes for the cache vs. attention.

Signed-off-by: Ian Wood <[email protected]>
1 parent 4c5d73a commit 9d8b44f

File tree

2 files changed (+1, -4 lines)


sharktank/sharktank/layers/paged_llama_attention_block.py

Lines changed: 1 addition & 3 deletions
@@ -31,7 +31,6 @@ def __init__(
         head_dim: int,
         head_count_kv: int,
         rms_epsilon: float,
-        attention_dtype: Optional[torch.dtype] = None,
         attention_kernel: str = "torch",
         attention_scale: Optional[float] = None,
         softcap: Optional[float] = None,
@@ -45,15 +44,14 @@ def __init__(
             attn_head_dim=head_dim,
             block_seq_stride=cache.block_seq_stride,
             cache_dtype=cache.cache_dtype,
-            attn_dtype=attention_dtype,
+            attn_dtype=cache.attn_dtype,
             device=cache.device,
             shard_count=cache.shard_count,
         )
         self.block_index = block_index
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
-        self.attention_dtype = attention_dtype
         self.attention_kernel = attention_kernel
         self.attention_scale = attention_scale
         self.softcap = softcap

sharktank/sharktank/models/llm/llm.py

Lines changed: 0 additions & 1 deletion
@@ -258,7 +258,6 @@ def __init__(
             head_dim=config.hp.attn_head_dim,
             head_count_kv=config.hp.attention_head_count_kv,
             rms_epsilon=config.hp.attention_layer_norm_rms_epsilon,
-            attention_dtype=config.attention_dtype,
             attention_kernel=attention_kernel,
             fake_quant=fake_quant,
             softcap=config.hp.attention_softcap,
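
For context, a minimal sketch of the dtype flow after this change, using hypothetical stand-in classes (DummyCache and DummyAttentionBlock are not real sharktank types): the attention block now reads attn_dtype from the cache object rather than accepting its own attention_dtype argument, so cache and attention dtypes are configured in one place while still being allowed to differ from each other.

# Minimal sketch with hypothetical stand-in classes; field names mirror the diff above.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DummyCache:
    # Stand-in for the paged KV cache configuration.
    cache_dtype: torch.dtype
    attn_dtype: torch.dtype
    block_seq_stride: int = 16
    device: Optional[torch.device] = None
    shard_count: int = 1


class DummyAttentionBlock:
    def __init__(self, cache: DummyCache):
        # After the fix, the attention dtype comes from the cache object,
        # mirroring `attn_dtype=cache.attn_dtype` in the diff above.
        self.attn_dtype = cache.attn_dtype
        self.cache_dtype = cache.cache_dtype


cache = DummyCache(cache_dtype=torch.float16, attn_dtype=torch.float32)
block = DummyAttentionBlock(cache)
assert block.cache_dtype is torch.float16  # KV cache stored in fp16
assert block.attn_dtype is torch.float32   # attention math runs in fp32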
