66
77from gllm .allocatorID import AllocatorID
88from gllm .sequence import Sequence
9- from gllm .dist_utils import get_pp_rank
109
1110
1211class MemoryManager ():
@@ -42,7 +41,7 @@ def __init__(self, gpu_memory_util: float, num_layers: int, dtype: torch.dtype,
4241 logger .info (f'Allocate { self .num_pages } pages ({ self .page_size } tokens/page)' )
4342
4443 self .segment = Segment (self .num_layers , self .num_pages ,
45- self .page_size , self .kv_head_num , self .kv_head_dim , self . dtype )
44+ self .page_size , self .kv_head_num , self .kv_head_dim )
4645
4746 def batch_store (self , layer_idx : int , k_cache : torch .Tensor , v_cache : torch .Tensor , slot_mapping_tensor : torch .Tensor ):
4847 from gllm import _custom_ops as ops
@@ -79,18 +78,17 @@ def __init__(self,
7978 num_pages : int ,
8079 page_size : int ,
8180 kv_head_num : int ,
82- kv_head_dim : int ,
83- dtype : torch .dtype ):
81+ kv_head_dim : int ):
8482 self .num_layers = num_layers
8583 self .num_pages = num_pages
8684 self .page_size = page_size
8785 self .kv_head_num = kv_head_num
8886 self .kv_head_dim = kv_head_dim
8987 # We don't need zero initialization here
9088 self .k_cache = [torch .ones (
91- (num_pages , page_size , kv_head_num , kv_head_dim ), dtype = dtype , device = 'cuda' ) for _ in range (num_layers )]
89+ (num_pages , page_size , kv_head_num , kv_head_dim )) for _ in range (num_layers )]
9290 self .v_cache = [torch .ones (
93- (num_pages , page_size , kv_head_num , kv_head_dim ), dtype = dtype , device = 'cuda' ) for _ in range (num_layers )]
91+ (num_pages , page_size , kv_head_num , kv_head_dim )) for _ in range (num_layers )]
9492 self .allocatorID = AllocatorID (0 , num_pages - 1 )
9593
9694 def allocate (self ):
@@ -113,7 +111,7 @@ def __init__(self, *args, **kwargs):
113111 super ().__init__ (* args , ** kwargs )
114112
115113 del self .segment
116- self .segment = PrefixSegment (self .num_layers , self .num_pages , self .page_size , self .kv_head_num , self .kv_head_dim , self . dtype )
114+ self .segment = PrefixSegment (self .num_layers , self .num_pages , self .page_size , self .kv_head_num , self .kv_head_dim )
117115
118116 # for prefill stage
119117 self .num_allocated_pages = 0
0 commit comments