Skip to content

Commit 1e19f46

Browse files
Maybe fix 3x3?
1 parent b55e714 commit 1e19f46

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

test/NeurekaMemoryLayout.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323

2424
class NeurekaMemoryLayout:
25-
_WEIGHT_BANDWIDTH = 256
25+
_WEIGHT_BANDWIDTH_1x1 = 256
26+
_WEIGHT_BANDWIDTH_3x3 = 288
2627
_CIN_SUBTILE_1x1 = 32
2728
_CIN_SUBTILE_3x3 = 32
2829

@@ -77,27 +78,29 @@ def weightEncode(
7778
weight = weight.reshape(-1, height * width * cinSubtile)
7879
# Pad only the last dimension to weight bandwidth size
7980
# (-1, Weight Bandwidth)
81+
print("DEBUG", weight.shape)
8082
weight = np.pad(
8183
weight,
82-
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH - weight.shape[-1])),
84+
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH_3x3 - weight.shape[-1])),
8385
"constant",
8486
constant_values=0,
8587
)
88+
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH_3x3 / 8))
8689
elif height == 1 and width == 1:
8790
# (cout * cinMajor, Bits * cinSubtile)
8891
weight = weight.reshape(-1, bits * cinSubtile)
8992
# Pad only the last dimension to weight bandwidth size
9093
# (-1, Weight Bandwidth)
9194
weight = np.pad(
9295
weight,
93-
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH - weight.shape[-1])),
96+
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 - weight.shape[-1])),
9497
"constant",
9598
constant_values=0,
9699
)
100+
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 / 8))
97101

98102
# Prepare for packing
99103
# (-1, Weight Bandwidth Bytes, 8)
100-
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH / 8))
101104
weight = np.stack(np.split(weight, weightBandwidthBytes, axis=-1), axis=-2)
102105

103106
# Pack bits

0 commit comments

Comments
 (0)