@@ -19,6 +19,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     gemm_half_q_half(x, q_handle, output, force_cuda)
     return output.view(output_shape)
 
+
 def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
     """
     Create Q matrix
@@ -60,63 +61,46 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
         temp_dq,
     )
 
+
 def temp_dq_size(inout_product):
     return inout_product * 2 + 128
 
-def temp_fwd_size(outfeatures, max_batch_tokens):
-    return outfeatures * max_batch_tokens * 4 + 128
 
-def scratch_space_fixed(inout_product, outfeatures, max_batch_tokens):
-    return temp_dq_size(inout_product) + temp_fwd_size(outfeatures, max_batch_tokens)
+def _elements(size_bytes):
+    size_bytes = (size_bytes + 127) & -128  # round up to nearest multiple of 128
+    return size_bytes // 2
+
 
-class ExLlamaV2DeviceTensors:
+class ExLlamaV2DeviceTensor:
     def __init__(self, device, scratch_bytes):
         self.device = device
-        self.scratch_bytes = scratch_bytes
-        self.scratch = None
-
-    def prepare(self):
-        print_rank_n(f"Allocating {self.scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
-        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype=torch.half, device=self.device)
+        print_rank_n(f"Allocating {scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
+        self.scratch = torch.empty(
+            _elements(scratch_bytes), dtype=torch.half, device=self.device
+        )
 
     def get_scratch_slice(self, size_bytes):
-        if self.scratch is None:
-            self.prepare()
+        size_half = _elements(size_bytes)
+        return self.scratch[:size_half]
 
-        size_bytes = ((size_bytes + 127) // 128) * 128
-        size_half = size_bytes // 2
-        scratch_slice = self.scratch.narrow(0, 0, size_half)
-        return scratch_slice
 
-# Max number of output features, used by temp_fwd_size calculation
-MAX_OUT_FEATURES = 1
-# Max of (infeatures * outfeatures), used by temp_dq_size calculation
-MAX_INOUT_PRODUCT = 1
 # DEVICE_TENSOR is a cuda buffer used by cublas gemm when M > 50
 DEVICE_TENSOR = None
 DEVICE = None
+# Max of (infeatures * outfeatures), used by temp_dq_size calculation
+MAX_INOUT_PRODUCT = 1
+
 
 def set_device(device):
-    global DEVICE
+    global DEVICE, DEVICE_TENSOR, MAX_INOUT_PRODUCT
     DEVICE = device
+    DEVICE_TENSOR = ExLlamaV2DeviceTensor(DEVICE, temp_dq_size(MAX_INOUT_PRODUCT))
 
-def create_exllama_buffers(max_batch_tokens: int):
-    global DEVICE, DEVICE_TENSOR, MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
-
-    assert DEVICE is not None, "call set_device first"
-
-    DEVICE_TENSOR = ExLlamaV2DeviceTensors(
-        DEVICE,
-        scratch_space_fixed(
-            MAX_INOUT_PRODUCT,
-            MAX_OUT_FEATURES,
-            max_batch_tokens,
-        ))
 
 class Ex4bitLinearV2(nn.Module):
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
-        global MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
+        global MAX_INOUT_PRODUCT
         super().__init__()
         assert bits == 4
 
@@ -134,9 +118,8 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         assert self.height % 32 == 0
         assert self.width % 32 == 0
 
-        # Update max outfeatures & inout_product so far for later call to create_exllama_buffers
-        MAX_OUT_FEATURES = max(MAX_OUT_FEATURES, self.width)
-        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width * self.height)
+        # Update max inout_product so far for later call to set_device
+        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width * self.height)
 
     def post_init(self):
         global DEVICE_TENSOR
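
As a sanity check on the new allocation scheme, the sketch below reproduces the sizing logic in isolation: the scratch buffer is allocated once in set_device from temp_dq_size(MAX_INOUT_PRODUCT), and get_scratch_slice now just returns a prefix view after rounding the requested byte count up to a multiple of 128. The inout_product value and the free-standing get_scratch_slice helper are illustrative stand-ins rather than part of this patch, and the example runs on CPU so no GPU is needed.

import torch

def _elements(size_bytes):
    # Round the request up to a multiple of 128 bytes, then convert bytes to a
    # count of fp16 elements (2 bytes each) -- same arithmetic as the new helper.
    size_bytes = (size_bytes + 127) & -128
    return size_bytes // 2

# Illustrative layer shape only; the real buffer is sized from MAX_INOUT_PRODUCT
# via temp_dq_size(inout_product) = inout_product * 2 + 128.
inout_product = 4096 * 11008
scratch_bytes = inout_product * 2 + 128

# One flat fp16 scratch tensor, allocated eagerly (no lazy prepare() step anymore).
# device="cpu" here so the sketch runs without a GPU.
scratch = torch.empty(_elements(scratch_bytes), dtype=torch.half, device="cpu")

def get_scratch_slice(size_bytes):
    # Prefix view into the shared buffer; nothing new is allocated per call.
    return scratch[: _elements(size_bytes)]

view = get_scratch_slice(1000)
assert view.numel() == 512                    # 1000 B -> 1024 B -> 512 fp16 elements
assert view.data_ptr() == scratch.data_ptr()  # a view of the same storage

Note that this ordering assumes set_device is called after all Ex4bitLinearV2 modules have been constructed, since MAX_INOUT_PRODUCT is only accumulated in their __init__.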