@@ -19,6 +19,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
    gemm_half_q_half(x, q_handle, output, force_cuda)
    return output.view(output_shape)

+
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
    """
    Create Q matrix
@@ -60,63 +61,46 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
        temp_dq,
    )

+
def temp_dq_size(inout_product):
    return inout_product * 2 + 128

-def temp_fwd_size(outfeatures, max_batch_tokens):
-    return outfeatures * max_batch_tokens * 4 + 128

-def scratch_space_fixed(inout_product, outfeatures, max_batch_tokens):
-    return temp_dq_size(inout_product) + temp_fwd_size(outfeatures, max_batch_tokens)
+def _elements(size_bytes):
+    size_bytes = (size_bytes + 127) & -128  # round up to nearest multiple of 128
+    return size_bytes // 2
+

-class ExLlamaV2DeviceTensors:
+class ExLlamaV2DeviceTensor:
    def __init__(self, device, scratch_bytes):
        self.device = device
-        self.scratch_bytes = scratch_bytes
-        self.scratch = None
-
-    def prepare(self):
-        print_rank_n(f"Allocating {self.scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
-        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype=torch.half, device=self.device)
+        print_rank_n(f"Allocating {scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
+        self.scratch = torch.empty(
+            _elements(scratch_bytes), dtype=torch.half, device=self.device
+        )

    def get_scratch_slice(self, size_bytes):
-        if self.scratch is None:
-            self.prepare()
+        size_half = _elements(size_bytes)
+        return self.scratch[:size_half]

-        size_bytes = ((size_bytes + 127) // 128) * 128
-        size_half = size_bytes // 2
-        scratch_slice = self.scratch.narrow(0, 0, size_half)
-        return scratch_slice

-# Max number of output features, used by temp_fwd_size calculation
-MAX_OUT_FEATURES = 1
-# Max of (infeatures * outfeatures), used by temp_dq_size calculation
-MAX_INOUT_PRODUCT = 1
# DEVICE_TENSOR is a cuda buffer used by cublas gemm when M > 50
DEVICE_TENSOR = None
DEVICE = None
+# Max of (infeatures * outfeatures), used by temp_dq_size calculation
+MAX_INOUT_PRODUCT = 1
+


def set_device(device):
-    global DEVICE
+    global DEVICE, DEVICE_TENSOR, MAX_INOUT_PRODUCT
    DEVICE = device
+    DEVICE_TENSOR = ExLlamaV2DeviceTensor(DEVICE, temp_dq_size(MAX_INOUT_PRODUCT))


-def create_exllama_buffers(max_batch_tokens: int):
-    global DEVICE, DEVICE_TENSOR, MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
-
-    assert DEVICE is not None, "call set_device first"
-
-    DEVICE_TENSOR = ExLlamaV2DeviceTensors(
-        DEVICE,
-        scratch_space_fixed(
-            MAX_INOUT_PRODUCT,
-            MAX_OUT_FEATURES,
-            max_batch_tokens,
-        ))


class Ex4bitLinearV2(nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
-        global MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
+        global MAX_INOUT_PRODUCT
        super().__init__()
        assert bits == 4
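A note on the `_elements` helper introduced in this hunk: it converts a byte count into a number of fp16 elements, first padding the byte count up to a multiple of 128 with the `(n + 127) & -128` bit trick, which matches the old `((n + 127) // 128) * 128` for non-negative sizes. `get_scratch_slice` then becomes a plain prefix slice of the preallocated buffer, which returns a view rather than a copy, just like the previous `narrow(0, 0, n)`. A minimal standalone sketch (CPU tensor and sizes are illustrative only):

```python
import torch

def _elements(size_bytes):
    # Pad to the next multiple of 128 bytes, then convert bytes -> fp16 elements.
    size_bytes = (size_bytes + 127) & -128
    return size_bytes // 2

assert _elements(1) == 64     # 1 byte  -> padded to 128 bytes -> 64 fp16 elements
assert _elements(128) == 64   # already aligned, no padding added
assert _elements(129) == 128  # 129 bytes -> padded to 256 bytes -> 128 fp16 elements

# Slicing a prefix returns a view sharing the same storage, not a copy.
scratch = torch.empty(_elements(1 << 20), dtype=torch.half)
piece = scratch[: _elements(4096)]
assert piece.data_ptr() == scratch.data_ptr()
```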
@@ -134,9 +118,8 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        assert self.height % 32 == 0
        assert self.width % 32 == 0

-        # Update max outfeatures & inout_product so far for later call to create_exllama_buffers
-        MAX_OUT_FEATURES = max(MAX_OUT_FEATURES, self.width)
-        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width * self.height)
+        # Update max inout_product so far for later call to set_device
+        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width * self.height)

    def post_init(self):
        global DEVICE_TENSOR
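With `create_exllama_buffers` removed, `set_device` now allocates the shared scratch buffer itself, sized by `temp_dq_size(MAX_INOUT_PRODUCT)`. That implies a call-order contract: every `Ex4bitLinearV2` must be constructed first, because construction is what grows `MAX_INOUT_PRODUCT`. A minimal sketch of the sizing arithmetic, using hypothetical layer shapes rather than real checkpoint tensors:

```python
MAX_INOUT_PRODUCT = 1

def temp_dq_size(inout_product):
    # As in the diff: 2 bytes per fp16 element of the largest weight, plus padding.
    return inout_product * 2 + 128

# Each Ex4bitLinearV2.__init__ bumps the tracked maximum like this:
for width, height in [(4096, 4096), (11008, 4096)]:  # hypothetical shapes
    MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, width * height)

# set_device, called afterwards, allocates once for the largest weight:
scratch_bytes = temp_dq_size(MAX_INOUT_PRODUCT)
print(f"{scratch_bytes / (1024 * 1024):.1f} MiB")  # ~86.0 MiB for these shapes
```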