
QUANTIZE_CONFIG_FILENAME = "quantize_config.json"

+def unpack(x, dim, bits=4):
+    return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)
+
+def unpack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0], x.shape[1] * pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[:, i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+def unpack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0] * pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+
+def pack(x, dim, bits=4):
+    return pack_row(x, bits) if dim == 0 else pack_col(x, bits)
+
+def pack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0], x.shape[1] // pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[:, i::pack_size] & mask) << (i * bits)
+    return packed_x
+
+def pack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0] // pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[i::pack_size] & mask) << (i * bits)
+    return packed_x
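A minimal round-trip sketch of the helpers above (illustrative only, not part of the diff; assumes 4-bit values, i.e. 0 <= x < 16, and shapes divisible by the pack factor 32 // 4 = 8):

    import torch

    x = torch.randint(0, 16, (8, 16), dtype=torch.int)  # fake 4-bit weights
    packed_rows = pack(x, dim=0)   # shape (1, 16): 8 row values folded into each int32
    packed_cols = pack(x, dim=1)   # shape (8, 2): 8 column values folded into each int32
    assert torch.equal(unpack(packed_rows, dim=0), x)
    assert torch.equal(unpack(packed_cols, dim=1), x)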

class Weights:
    def __init__(
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
        tensor = tensor.to(device=self.device)
        return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        if perm is None:
+            return self.get_partial_sharded(tensor_name, dim)
+        else:
+            return self.get_shuffle_sharded(tensor_name, dim, perm, packed)
+
+    def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
+        filename, tensor_name = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()

-    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        perm = perm.to(device=tensor.device)
+        size = tensor.shape[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        # TODO: pack-unpack on cuda to speed up this part
+        if dim == 0:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
+            else:
+                tensor = tensor[perm][start:stop]
+        elif dim == 1:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
+            else:
+                tensor = tensor[:, perm][:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
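The permute-then-slice step of get_shuffle_sharded on a toy tensor (illustrative sketch, not part of the diff):

    import torch

    full = torch.arange(8).reshape(1, 8)           # stand-in for one weight row
    perm = torch.tensor([7, 5, 3, 1, 6, 4, 2, 0])  # act-order style column permutation
    world_size, rank = 2, 0
    block_size = full.shape[1] // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    local = full[:, perm][:, start:stop]           # rank 0 gets columns [7, 5, 3, 1]
    # rank 1 would get columns [6, 4, 2, 0]; packed tensors are first unpacked,
    # permuted the same way, then repacked before slicing.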
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
        if quantize == "gptq":
            try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
@@ -141,39 +215,36 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
            weight = torch.cat(w, dim=dim)
        return weight

-    def get_multi_weights_row(self, prefix: str, quantize: str):
+    def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
        if quantize == "gptq":
            bits, groupsize = self._get_gptq_params()

-            use_exllama = bits == 4
-
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_exllama = False
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+            is_preshuffle = row_perm is not None
+            is_masked_matmul = noshard
+            assert (is_preshuffle != is_masked_matmul
+                    or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimizations can't both be enabled at the same time: {is_preshuffle=}, {is_masked_matmul=}"
+            use_exllama = ((bits == 4) and HAS_EXLLAMA) or (is_preshuffle or is_masked_matmul)
+            if self.process_group.rank == 0:
+                if use_exllama:
+                    logger.info(f"Using exllama kernels for row {prefix}")
+                else:
+                    logger.warning(
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        " or not currently installed, try using BUILD_EXTENSIONS=True"
+                    )

            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight",
+                                           dim=0,
+                                           perm=row_perm if use_exllama else None,
+                                           packed=True,
+                                           ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            from text_generation_server.utils.layers import HAS_EXLLAMA
-            if use_exllama:
-                use_exllama = HAS_EXLLAMA
-                if self.process_group.rank == 0:
-                    if use_exllama:
-                        logger.info(f"Using exllama kernels for row {prefix}")
-                    else:
-                        logger.warning(
-                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
-                            " or not currently installed, try using BUILD_EXTENSIONS=True"
-                        )
-
            if use_exllama:
-                if groupsize >= 0:
+                if groupsize >= 0 and not is_masked_matmul:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -183,7 +254,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                    scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
+                if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
@@ -197,7 +268,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+            weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
        return weight

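Hypothetical call sites for the two new knobs (the weights instance, layer name, and perm tensor are illustrative, not from the diff; per the assertion above, the pre-shuffle and masked-matmul paths are mutually exclusive):

    # pre-shuffle: apply a row permutation globally, then shard the permuted rows
    weight = weights.get_multi_weights_row("model.layers.0.mlp.down_proj", quantize="gptq", row_perm=perm)

    # masked matmul: skip sharding and load the full tensor on every rank
    weight = weights.get_multi_weights_row("model.layers.0.mlp.down_proj", quantize="gptq", noshard=True)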
    def _get_gptq_params(self) -> Tuple[int, int]: