@@ -110,7 +110,9 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
110110 self .kv_move_buffer = torch .empty (
111111 (page_num , page_size , self .layer_num , 2 * num_kv_head , self .head_dim ), dtype = self .dtype , device = "cuda"
112112 )
113- self ._buffer_mem_indexes_tensors = [torch .empty ((page_size ,), dtype = torch .int64 , device = "cpu" , pin_memory = True ) for _ in range (page_num ) ]
113+ self ._buffer_mem_indexes_tensors = [
114+ torch .empty ((page_size ,), dtype = torch .int64 , device = "cpu" , pin_memory = True ) for _ in range (page_num )
115+ ]
114116 return self .kv_move_buffer
115117
116118 def write_mem_to_page_kv_move_buffer (
@@ -122,11 +124,9 @@ def write_mem_to_page_kv_move_buffer(
122124 dp_world_size : int ,
123125 ):
124126 cur_page = self .kv_move_buffer [page_index ]
125- pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
127+ pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
126128 pin_mem_indexes .numpy ()[:] = mem_indexes
127- mem_indexes_gpu = pin_mem_indexes .cuda (
128- non_blocking = True
129- )
129+ mem_indexes_gpu = pin_mem_indexes .cuda (non_blocking = True )
130130 repeat_count = dp_world_size * self .kv_buffer .shape [2 ] // self .kv_move_buffer .shape [3 ]
131131 dp_mems = mem_managers [(dp_index * dp_world_size ) : ((dp_index + 1 ) * dp_world_size )]
132132 for tp_index in range (dp_world_size ):
@@ -153,11 +153,9 @@ def read_page_kv_move_buffer_to_mem(
153153 dp_world_size : int ,
154154 ):
155155 cur_page = self .kv_move_buffer [page_index ]
156- pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
156+ pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
157157 pin_mem_indexes .numpy ()[:] = mem_indexes
158- mem_indexes_gpu = pin_mem_indexes .cuda (
159- non_blocking = True
160- )
158+ mem_indexes_gpu = pin_mem_indexes .cuda (non_blocking = True )
161159 dp_mems = mem_managers [(dp_index * dp_world_size ) : ((dp_index + 1 ) * dp_world_size )]
162160 mem_indexes_gpu = torch .tensor (mem_indexes , dtype = torch .int64 , device = "cpu" , pin_memory = True ).cuda (
163161 non_blocking = True
0 commit comments