@@ -69,27 +69,30 @@ def export_to_hfile(quantized_model, filename, runname):
6969
7070 print (f'Layer: { layer } Quantization type: <{ quantization_type } >, Bits per weight: { bpw } , Num. incoming: { incoming_weights } , Num outgoing: { outgoing_weights } ' )
7171
72+ data_type = np .uint32
73+
7274 if quantization_type == 'Binary' :
7375 encoded_weights = np .where (weights == - 1 , 0 , 1 )
7476 QuantID = 1
7577 elif quantization_type == '2bitsym' : # encoding -1.5 -> 11, -0.5 -> 10, 0.5 -> 00, 1.5 -> 01 (one complement with offset)
76- encoded_weights = ((weights < 0 ).astype (int ) << 1 ) | (np .floor (np .abs (weights ))).astype (int ) # use bitwise operations to encode the weights
78+ encoded_weights = ((weights < 0 ).astype (data_type ) << 1 ) | (np .floor (np .abs (weights ))).astype (data_type ) # use bitwise operations to encode the weights
7779 QuantID = 2
7880 elif quantization_type == '4bitsym' :
79- encoded_weights = ((weights < 0 ).astype (int ) << 3 ) | (np .floor (np .abs (weights ))).astype (int ) # use bitwise operations to encode the weights
81+ encoded_weights = ((weights < 0 ).astype (data_type ) << 3 ) | (np .floor (np .abs (weights ))).astype (data_type ) # use bitwise operations to encode the weights
8082 QuantID = 4
8183 elif quantization_type == 'FP130' : # FP1.3.0 encoding (sign * 2^exp)
82- encoded_weights = ((weights < 0 ).astype (int ) << 3 ) | (np .floor (np .log2 (np .abs (weights )))).astype (int )
84+ encoded_weights = ((weights < 0 ).astype (data_type ) << 3 ) | (np .floor (np .log2 (np .abs (weights )))).astype (data_type )
8385 QuantID = 16 + 4
8486 else :
8587 print (f'Skipping layer { layer } with quantization type { quantization_type } and { bpw } bits per weight. Quantization type not supported.' )
8688
8789 # pack bits into 32 bit words
8890 weight_per_word = 32 // bpw
8991 reshaped_array = encoded_weights .reshape (- 1 , weight_per_word )
90- bit_positions = 32 - bpw - np .arange (weight_per_word ) * bpw
91- packed_weights = np .bitwise_or .reduce (reshaped_array << bit_positions , axis = 1 ).view (np .uint32 )
92-
92+
93+ bit_positions = 32 - bpw - np .arange (weight_per_word , dtype = data_type ) * bpw
94+ packed_weights = np .bitwise_or .reduce (reshaped_array << bit_positions , axis = 1 ).view (data_type )
95+
9396 # print(f'weights: {weights.shape} {weights.flatten()[0:16]}')
9497 # print(f'Encoded weights: {encoded_weights.shape} {encoded_weights.flatten()[0:16]}')
9598 # print(f'Packed weights: {packed_weights.shape} {", ".join(map(lambda x: hex(x), packed_weights.flatten()[0:4]))}')
@@ -338,4 +341,6 @@ def plot_weight_histograms(quantized_model):
338341 # export the quantized model to a header file
339342 # export_to_hfile(quantized_model, f'{exportfolder}/{runname}.h')
340343 export_to_hfile (quantized_model , f'BitNetMCU_model.h' ,runname )
341- plt .show ()
344+
345+ if showplots :
346+ plt .show ()
0 commit comments