
Commit 5699ea1

[sharktank] Fix attention dtype (#1243)
Fixes a bug from the refactor in #1098 that removed the ability to specify different dtypes for the cache vs. attention.

Signed-off-by: Ian Wood <[email protected]>
1 parent 244a5e9 commit 5699ea1
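For context, a minimal sketch of the split introduced by this commit: the page table is now allocated in cache_dtype while the attention math runs in attn_dtype. The parameter names come from this diff; the import path and the concrete values below are illustrative assumptions, not taken from the commit.

import torch
from sharktank.layers.paged_attention import PagedAttention

# Sketch only: example values, assumed import path.
cache = PagedAttention(
    transformer_block_count=26,
    attn_head_count=8,
    attn_head_dim=128,
    block_seq_stride=16,
    cache_partition_count=2,  # one partition each for K and V
    cache_dtype=torch.float16,  # dtype the KV page table is stored in
    attn_dtype=torch.float32,   # dtype q/k/v are cast to for the attention computation
    device=None,
    shard_count=1,
)

# Pages are allocated in cache_dtype; reads are converted to attn_dtype before attention.
cache_state = cache.allocate(128)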

10 files changed: +34 −25 lines
sharktank/sharktank/export_layer/export_kv_cache.py

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ def main():
         attn_head_count=attn_head_count,
         attn_head_dim=attn_head_dim,
         shard_count=args.sharding,
-        dtype=torch.float32,
+        cache_dtype=torch.float32,
+        attn_dtype=torch.float32,
         device=None,
     )

sharktank/sharktank/layers/paged_attention.py

Lines changed: 11 additions & 9 deletions
@@ -60,7 +60,8 @@ def __init__(
         attn_head_dim: int,
         cache_partition_count: int = 2,
         block_seq_stride: int = 16,
-        dtype: torch.dtype = torch.float32,
+        cache_dtype: torch.dtype = torch.float32,
+        attn_dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         shard_count: int = 1,
     ):
@@ -85,7 +86,8 @@ def __init__(
         ]
         self.page_slab_flat_dim = math.prod(self.sub_page_dims)
         self.device = device
-        self.dtype = dtype
+        self.cache_dtype = cache_dtype
+        self.attn_dtype = attn_dtype

     def unflatten_page_table(
         self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
@@ -146,7 +148,7 @@ def allocate(
         shards = [
             torch.empty(
                 [page_count, self.page_slab_flat_dim],
-                dtype=self.dtype,
+                dtype=self.cache_dtype,
                 device=self.device,
             )
             for _ in range(self.shard_count)
@@ -356,18 +358,18 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

         # Fake quant is already dequantized when stored in the cache.
         if cache_quantizer and not fake_quant:
-            k = cache_quantizer.dequantize_raw_tensor(k, self.dtype, name="xk_deq")
-            v = cache_quantizer.dequantize_raw_tensor(v, self.dtype, name="xv_deq")
+            k = cache_quantizer.dequantize_raw_tensor(k, self.attn_dtype, name="xk_deq")
+            v = cache_quantizer.dequantize_raw_tensor(v, self.attn_dtype, name="xv_deq")

         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

-        q = ops.to(q, dtype=self.dtype)
-        k = ops.to(k, dtype=self.dtype)
-        v = ops.to(v, dtype=self.dtype)
+        q = ops.to(q, dtype=self.attn_dtype)
+        k = ops.to(k, dtype=self.attn_dtype)
+        v = ops.to(v, dtype=self.attn_dtype)
         if mask is not None:
-            mask = ops.to(mask, dtype=self.dtype)
+            mask = ops.to(mask, dtype=self.attn_dtype)

         # Decomposed
         if attention_kernel == "decomposed":

sharktank/sharktank/layers/paged_llama_attention_block.py

Lines changed: 2 additions & 3 deletions
@@ -31,7 +31,6 @@ def __init__(
         head_dim: int,
         head_count_kv: int,
         rms_epsilon: float,
-        attention_dtype: Optional[torch.dtype] = None,
         attention_kernel: str = "torch",
         attention_scale: Optional[float] = None,
         softcap: Optional[float] = None,
@@ -44,15 +43,15 @@ def __init__(
             attn_head_count=head_count_kv,
             attn_head_dim=head_dim,
             block_seq_stride=cache.block_seq_stride,
-            dtype=cache.dtype,
+            cache_dtype=cache.cache_dtype,
+            attn_dtype=cache.attn_dtype,
             device=cache.device,
             shard_count=cache.shard_count,
         )
         self.block_index = block_index
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
-        self.attention_dtype = attention_dtype
         self.attention_kernel = attention_kernel
         self.attention_scale = attention_scale
         self.softcap = softcap

sharktank/sharktank/models/llm/llm.py

Lines changed: 0 additions & 1 deletion
@@ -258,7 +258,6 @@ def __init__(
                 head_dim=config.hp.attn_head_dim,
                 head_count_kv=config.hp.attention_head_count_kv,
                 rms_epsilon=config.hp.attention_layer_norm_rms_epsilon,
-                attention_dtype=config.attention_dtype,
                 attention_kernel=attention_kernel,
                 fake_quant=fake_quant,
                 softcap=config.hp.attention_softcap,

sharktank/sharktank/utils/create_cache.py

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedAttention:
         cache_partition_count=2, # One for each of K/V.
         block_seq_stride=config.block_seq_stride,
         device=config.device,
-        dtype=dtype,
+        cache_dtype=dtype,
+        attn_dtype=config.attention_dtype,
         shard_count=config.tensor_parallelism_size,
     )

sharktank/tests/layers/kv_cache_test.py

Lines changed: 4 additions & 2 deletions
@@ -33,7 +33,8 @@ def test_paged(dtype: torch.dtype):
         transformer_block_count=transformer_block_count,
         attn_head_count=attn_head_count,
         attn_head_dim=attn_head_dim,
-        dtype=dtype,
+        cache_dtype=dtype,
+        attn_dtype=dtype,
         device=None,
     )

@@ -142,7 +143,8 @@ def test_sharded_paged():
         attn_head_count=attn_head_count,
         attn_head_dim=attn_head_dim,
         shard_count=shard_count,
-        dtype=torch.float32,
+        cache_dtype=torch.float32,
+        attn_dtype=torch.float32,
         device=None,
     )

sharktank/tests/layers/paged_llama_attention_block_test.py

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ def testExportNondecomposed(self):
             attn_head_dim=self.attention_head_dim,
             cache_partition_count=self.cache_partition_count,
             block_seq_stride=self.block_seq_stride,
-            dtype=dtype,
+            cache_dtype=dtype,
+            attn_dtype=dtype,
         )

         cache_state = cache.allocate(self.page_count)

sharktank/tests/layers/sharded_paged_kv_cache_test.py

Lines changed: 4 additions & 2 deletions
@@ -38,7 +38,8 @@ def setUp(self):
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
             cache_partition_count=self.cache_partition_count,
-            dtype=self.dtype,
+            cache_dtype=self.dtype,
+            attn_dtype=self.dtype,
         )
         self.sharded_cache = PagedAttention(
             shard_count=self.shard_count,
@@ -47,7 +48,8 @@ def setUp(self):
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
             cache_partition_count=self.cache_partition_count,
-            dtype=self.dtype,
+            cache_dtype=self.dtype,
+            attn_dtype=self.dtype,
         )

     def make_unsharded_and_sharded_equal_cache_states(

sharktank/tests/layers/sharded_paged_llama_attention_block.py

Lines changed: 5 additions & 4 deletions
@@ -64,16 +64,17 @@ def make_paged_kv_cache(shard_count: int) -> PagedAttention:
                 attn_head_dim=self.attention_head_dim,
                 cache_partition_count=self.cache_partition_count,
                 block_seq_stride=self.block_seq_stride,
-                dtype=dtype,
+                cache_dtype=dtype,
+                attn_dtype=dtype,
                 shard_count=shard_count,
             )

         cache = make_paged_kv_cache(shard_count=1)
         sharded_cache = make_paged_kv_cache(shard_count=self.shard_count)

-        def make_unsharded_and_sharded_equal_cache_states() -> tuple[
-            list[torch.Tensor], list[SplitPrimitiveTensor]
-        ]:
+        def make_unsharded_and_sharded_equal_cache_states() -> (
+            tuple[list[torch.Tensor], list[SplitPrimitiveTensor]]
+        ):
             cache_state = cache.allocate(self.page_count)
             cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype)
             sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state))

sharktank/tests/models/llama/attention_test.py

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ def test(self):
             cache_partition_count=2, # One for each of K/V.
             block_seq_stride=block_seq_stride,
             device="cpu",
-            dtype=torch.float32,
+            cache_dtype=torch.float32,
+            attn_dtype=torch.float32,
         )
         attention_block = AttentionFFNBlock(
             theta=attention_block_theta,
