Commit 7561169

Merge pull request #3 from donn/main
Tweaks for running on Linux
2 parents a1263be + 3b831da commit 7561169

File tree

5 files changed: +38 -14 lines


.gitignore

Lines changed: 4 additions & 4 deletions
@@ -18,7 +18,7 @@ backup/
 *.pdf
 # python cache
 __pycache__/
-
-
-
-
+venv/
+# ides
+.vscode/
+.idea/

BitNetMCU_MNIST_dll.c

Lines changed: 9 additions & 2 deletions
@@ -13,8 +13,15 @@
  * @return The result of the inference.
  */
 
+uint32_t BitMnistInference(int8_t *input);
+
 #ifdef _DLL
-__declspec(dllexport) uint32_t Inference(int8_t *input) {
+#ifdef WIN32
+#define EXPORT __declspec(dllexport)
+#else
+#define EXPORT __attribute__((visibility("default")))
+#endif
+EXPORT uint32_t Inference(int8_t *input) {
     return BitMnistInference(input);
 }
 #endif
@@ -63,4 +70,4 @@ uint32_t BitMnistInference(int8_t *input) {
     return ReLUNorm(layer_out, layer_in, L3_outgoing_weights);
 #endif
 
-}
+}
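With the EXPORT macro covering both MSVC and GCC/Clang, the same wrapper can be loaded from Python on Windows or Linux. A minimal ctypes sketch follows; the library name Bitnet_inf.dll comes from the Makefile in this commit, while the 256-element int8 input and the zeroed test image are assumptions made purely for illustration, not taken from this diff.

# Minimal sketch: load the shared library built with -D _DLL and call the
# exported Inference() symbol. Assumes the library is named Bitnet_inf.dll
# (as in the Makefile) and sits in the current directory; the 16x16 = 256
# int8 input length is an assumption about the model, not part of this diff.
import ctypes
import numpy as np

lib = ctypes.CDLL("./Bitnet_inf.dll")
lib.Inference.argtypes = [ctypes.POINTER(ctypes.c_int8)]
lib.Inference.restype = ctypes.c_uint32

# Placeholder input: an all-zero "image" just to exercise the call path.
image = np.zeros(256, dtype=np.int8)
ptr = image.ctypes.data_as(ctypes.POINTER(ctypes.c_int8))
predicted_class = lib.Inference(ptr)
print("predicted class:", predicted_class)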

Makefile

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+SOURCES = BitNetMCU_MNIST_dll.c BitNetMCU_inference.c
+HEADERS = BitNetMCU_model.h BitNetMCU_inference.h
+DLL = Bitnet_inf.dll
+
+$(DLL): $(SOURCES) $(HEADERS)
+	cc -fPIC -shared -o $@ -D _DLL $<
+
+.PHONY: clean
+clean:
+	rm -f $(DLL)

exportquant.py

Lines changed: 12 additions & 7 deletions
@@ -69,27 +69,30 @@ def export_to_hfile(quantized_model, filename, runname):
 
         print(f'Layer: {layer} Quantization type: <{quantization_type}>, Bits per weight: {bpw}, Num. incoming: {incoming_weights}, Num outgoing: {outgoing_weights}')
 
+        data_type = np.uint32
+
         if quantization_type == 'Binary':
             encoded_weights = np.where(weights == -1, 0, 1)
             QuantID = 1
         elif quantization_type == '2bitsym': # encoding -1.5 -> 11, -0.5 -> 10, 0.5 -> 00, 1.5 -> 01 (one complement with offset)
-            encoded_weights = ((weights < 0).astype(int) << 1) | (np.floor(np.abs(weights))).astype(int) # use bitwise operations to encode the weights
+            encoded_weights = ((weights < 0).astype(data_type) << 1) | (np.floor(np.abs(weights))).astype(data_type) # use bitwise operations to encode the weights
             QuantID = 2
         elif quantization_type == '4bitsym':
-            encoded_weights = ((weights < 0).astype(int) << 3) | (np.floor(np.abs(weights))).astype(int) # use bitwise operations to encode the weights
+            encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.abs(weights))).astype(data_type) # use bitwise operations to encode the weights
             QuantID = 4
         elif quantization_type == 'FP130': # FP1.3.0 encoding (sign * 2^exp)
-            encoded_weights = ((weights < 0).astype(int) << 3) | (np.floor(np.log2(np.abs(weights)))).astype(int)
+            encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.log2(np.abs(weights)))).astype(data_type)
            QuantID = 16 + 4
         else:
            print(f'Skipping layer {layer} with quantization type {quantization_type} and {bpw} bits per weight. Quantization type not supported.')
 
         # pack bits into 32 bit words
         weight_per_word = 32 // bpw
         reshaped_array = encoded_weights.reshape(-1, weight_per_word)
-        bit_positions = 32 - bpw - np.arange(weight_per_word) * bpw
-        packed_weights = np.bitwise_or.reduce(reshaped_array << bit_positions, axis=1).view(np.uint32)
-
+
+        bit_positions = 32 - bpw - np.arange(weight_per_word, dtype=data_type) * bpw
+        packed_weights = np.bitwise_or.reduce(reshaped_array << bit_positions, axis=1).view(data_type)
+
         # print(f'weights: {weights.shape} {weights.flatten()[0:16]}')
         # print(f'Encoded weights: {encoded_weights.shape} {encoded_weights.flatten()[0:16]}')
         # print(f'Packed weights: {packed_weights.shape} {", ".join(map(lambda x: hex(x), packed_weights.flatten()[0:4]))}')
@@ -338,4 +341,6 @@ def plot_weight_histograms(quantized_model):
     # export the quantized model to a header file
     # export_to_hfile(quantized_model, f'{exportfolder}/{runname}.h')
     export_to_hfile(quantized_model, f'BitNetMCU_model.h',runname)
-    plt.show()
+
+    if showplots:
+        plt.show()
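Making data_type = np.uint32 explicit keeps the sign bit, the left shifts, and the final packing in an unsigned 32-bit domain rather than the platform-default integer type. Below is a small self-contained sketch of the same 4bitsym encode-and-pack path, using made-up weight values purely for illustration.

# Standalone sketch of the 4bitsym encode-and-pack scheme used above.
# The weights below are made-up example values; with bpw = 4, eight weights
# fit into one 32-bit word.
import numpy as np

data_type = np.uint32
bpw = 4
weights = np.array([0.5, -0.5, 1.5, -2.5, 3.5, -3.5, 2.5, 1.5])  # example values

# sign goes to bit 3, floored magnitude to bits 0..2
encoded = ((weights < 0).astype(data_type) << 3) | np.floor(np.abs(weights)).astype(data_type)

weights_per_word = 32 // bpw
reshaped = encoded.reshape(-1, weights_per_word)
bit_positions = 32 - bpw - np.arange(weights_per_word, dtype=data_type) * bpw
packed = np.bitwise_or.reduce(reshaped << bit_positions, axis=1).view(data_type)

print([hex(w) for w in packed])  # one packed 32-bit word for the eight weights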

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 torch
 torchvision
 numpy
-PyYAML
+PyYAML
+tensorboard
+matplotlib
