@@ -60,7 +60,8 @@ def __init__(
         attn_head_dim: int,
         cache_partition_count: int = 2,
         block_seq_stride: int = 16,
-        dtype: torch.dtype = torch.float32,
+        cache_dtype: torch.dtype = torch.float32,
+        attn_dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         shard_count: int = 1,
     ):
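
This hunk splits the single `dtype` knob into `cache_dtype` (the storage precision of the page slab) and `attn_dtype` (the compute precision of attention). A standalone sketch of why the two are independent, using illustrative shapes rather than anything taken from this diff: the cache footprint tracks only `cache_dtype`, while the numerics of attention track only `attn_dtype`.

import torch

# Illustrative numbers, not from the diff: the slab's memory cost depends
# solely on the storage dtype chosen for the cache.
page_count, page_slab_flat_dim = 2048, 2 * 16 * 8 * 128
for cache_dtype in (torch.float32, torch.float16):
    slab = torch.empty(page_count, page_slab_flat_dim, dtype=cache_dtype)
    print(cache_dtype, slab.element_size() * slab.nelement() / 2**20, "MiB")
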
@@ -85,7 +86,8 @@ def __init__(
         ]
         self.page_slab_flat_dim = math.prod(self.sub_page_dims)
         self.device = device
-        self.dtype = dtype
+        self.cache_dtype = cache_dtype
+        self.attn_dtype = attn_dtype
 
     def unflatten_page_table(
         self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
@@ -146,7 +148,7 @@ def allocate(
         shards = [
             torch.empty(
                 [page_count, self.page_slab_flat_dim],
-                dtype=self.dtype,
+                dtype=self.cache_dtype,
                 device=self.device,
             )
             for _ in range(self.shard_count)
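
Allocation now uses the storage dtype: each shard gets one flat slab created in `cache_dtype`, so choosing a narrower storage type shrinks the cache without touching attention precision. A plain-torch mirror of this hunk, assuming a list of ordinary tensors stands in for the repo's sharded `SplitPrimitiveTensor`:

import torch

# Hedged sketch of allocate(): one flat page slab per shard, created in the
# storage dtype. Function name and argument defaults are assumptions.
def allocate_pages(page_count, page_slab_flat_dim, cache_dtype, shard_count, device=None):
    return [
        torch.empty([page_count, page_slab_flat_dim], dtype=cache_dtype, device=device)
        for _ in range(shard_count)
    ]

shards = allocate_pages(256, 4096, torch.float16, shard_count=2)
assert all(s.dtype is torch.float16 for s in shards)
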
@@ -356,18 +358,18 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
 
         # Fake quant is already dequantized when stored in the cache.
         if cache_quantizer and not fake_quant:
-            k = cache_quantizer.dequantize_raw_tensor(k, self.dtype, name="xk_deq")
-            v = cache_quantizer.dequantize_raw_tensor(v, self.dtype, name="xv_deq")
+            k = cache_quantizer.dequantize_raw_tensor(k, self.attn_dtype, name="xk_deq")
+            v = cache_quantizer.dequantize_raw_tensor(v, self.attn_dtype, name="xv_deq")
 
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
-        q = ops.to(q, dtype=self.dtype)
-        k = ops.to(k, dtype=self.dtype)
-        v = ops.to(v, dtype=self.dtype)
+        q = ops.to(q, dtype=self.attn_dtype)
+        k = ops.to(k, dtype=self.attn_dtype)
+        v = ops.to(v, dtype=self.attn_dtype)
         if mask is not None:
-            mask = ops.to(mask, dtype=self.attn_dtype)
+            mask = ops.to(mask, dtype=self.attn_dtype)
 
         # Decomposed
         if attention_kernel == "decomposed":
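
On the read path, k/v come back from the cache in the storage dtype (possibly quantized) and everything is cast to `attn_dtype` before attention runs. A self-contained plain-torch mirror of this hunk, with `ops.to` replaced by `.to()` and the quantized `dequantize_raw_tensor` branch elided:

import torch

# Hedged sketch of the cast-to-compute-dtype step; function name and
# tensor shapes are assumptions, not taken from the repo.
def attend(q, k, v, mask, attn_dtype=torch.float32):
    q = q.transpose(1, 2).to(attn_dtype)
    k = k.transpose(1, 2).to(attn_dtype)
    v = v.transpose(1, 2).to(attn_dtype)
    if mask is not None:
        mask = mask.to(attn_dtype)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = torch.randn(1, 8, 4, 64)                            # (batch, seq, heads, head_dim)
k = v = torch.randn(1, 8, 4, 64, dtype=torch.float16)   # stands in for the cache dtype
out = attend(q, k, v, mask=None)
assert out.dtype is torch.float32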