Commit 1b0a912

Remove the redundant dtype setting (#61)
1 parent 8b1b118

File tree

2 files changed: +7 -7 lines


gllm/memory_manager.py

Lines changed: 5 additions & 7 deletions
@@ -6,7 +6,6 @@
 
 from gllm.allocatorID import AllocatorID
 from gllm.sequence import Sequence
-from gllm.dist_utils import get_pp_rank
 
 
 class MemoryManager():
@@ -42,7 +41,7 @@ def __init__(self, gpu_memory_util: float, num_layers: int, dtype: torch.dtype,
         logger.info(f'Allocate {self.num_pages} pages ({self.page_size} tokens/page)')
 
         self.segment = Segment(self.num_layers, self.num_pages,
-                               self.page_size, self.kv_head_num, self.kv_head_dim, self.dtype)
+                               self.page_size, self.kv_head_num, self.kv_head_dim)
 
     def batch_store(self, layer_idx: int, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping_tensor: torch.Tensor):
         from gllm import _custom_ops as ops
@@ -79,18 +78,17 @@ def __init__(self,
                  num_pages: int,
                  page_size: int,
                  kv_head_num: int,
-                 kv_head_dim: int,
-                 dtype: torch.dtype):
+                 kv_head_dim: int):
         self.num_layers = num_layers
         self.num_pages = num_pages
         self.page_size = page_size
         self.kv_head_num = kv_head_num
         self.kv_head_dim = kv_head_dim
         # We don't need zero initialization here
         self.k_cache = [torch.ones(
-            (num_pages, page_size, kv_head_num, kv_head_dim), dtype=dtype, device='cuda') for _ in range(num_layers)]
+            (num_pages, page_size, kv_head_num, kv_head_dim)) for _ in range(num_layers)]
         self.v_cache = [torch.ones(
-            (num_pages, page_size, kv_head_num, kv_head_dim), dtype=dtype, device='cuda') for _ in range(num_layers)]
+            (num_pages, page_size, kv_head_num, kv_head_dim)) for _ in range(num_layers)]
         self.allocatorID = AllocatorID(0, num_pages-1)
 
     def allocate(self):
@@ -113,7 +111,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         del self.segment
-        self.segment = PrefixSegment(self.num_layers, self.num_pages, self.page_size, self.kv_head_num, self.kv_head_dim, self.dtype)
+        self.segment = PrefixSegment(self.num_layers, self.num_pages, self.page_size, self.kv_head_num, self.kv_head_dim)
 
         # for prefill stage
         self.num_allocated_pages = 0
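Note on the Segment change above: the torch.ones calls drop not only dtype=dtype but also device='cuda', so the cache tensors now inherit PyTorch's process-wide factory defaults. A minimal sketch of the equivalence; the torch.set_default_device call is an assumption for illustration, since this commit only shows the default dtype being set in load_model:

    import torch

    torch.set_default_dtype(torch.float16)  # what load_model() now does (see below)
    if torch.cuda.is_available():
        # Assumption: gllm points factory functions at the GPU elsewhere;
        # this diff only removes the explicit device='cuda' argument.
        torch.set_default_device('cuda')

    old_style = torch.ones((4, 16), dtype=torch.float16)  # pre-commit call shape
    new_style = torch.ones((4, 16))                       # post-commit call shape
    assert new_style.dtype == old_style.dtype == torch.float16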

gllm/model_loader.py

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ def get_model_type(self):
 
     def load_model(self, mp_load_progress=None):
         model_type = self.get_model_type()
+
+        logger.info(f'Set default dtype: {self.dtype}')
         torch.set_default_dtype(self.dtype)
 
         if self.load_format == 'auto':
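One ordering implication of relying on the default (a sketch, not part of the diff): torch.set_default_dtype only affects factory calls made after it runs, so load_model must set the default before MemoryManager builds its Segment, otherwise the caches would come out as float32:

    import torch

    page = torch.ones((16, 8))               # before the default is changed
    print(page.dtype)                        # torch.float32

    torch.set_default_dtype(torch.bfloat16)  # load_model() runs first...
    kv_page = torch.ones((16, 8))            # ...then Segment allocates its caches
    print(kv_page.dtype)                     # torch.bfloat16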
