Skip to content

Commit 973cd49

Browse files
Tentative fix for QW<8 bit
This fixes layout + runtime for QW<8 bit. Tested only on pointwise and only on the special scenario of synthetic weights, for now.
1 parent 140fc2c commit 973cd49

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

neureka/hal/neureka_task.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ void neureka_task_set_strides(neureka_task_t *task, const uint32_t k_in,
169169
if (task->kernel_shape == 1) { // 1x1
170170
task->data.cfg.weights_stride.d0 = NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1;
171171
task->data.cfg.weights_stride.d1 =
172-
NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1 * num_k_in;
172+
(NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1 / 8) * task->qw * num_k_in;
173173
} else if (!task->depthwise) { // 3x3
174174
task->data.cfg.weights_stride.d0 = NEUREKA_WEIGHT_BANDWIDTH_BYTES_3x3;
175175
task->data.cfg.weights_stride.d1 =

test/NeurekaMemoryLayout.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,8 @@ def weightEncode(
8888
elif height == 1 and width == 1:
8989
# (cout * cinMajor, Bits * cinSubtile)
9090
weight = weight.reshape(-1, bits * cinSubtile)
91-
# Pad only the last dimension to weight bandwidth size
92-
# (-1, Weight Bandwidth)
93-
weight = np.pad(
94-
weight,
95-
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 - weight.shape[-1])),
96-
"constant",
97-
constant_values=0,
98-
)
99-
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 / 8))
91+
# No padding needed here
92+
weightBandwidthBytes = int(np.ceil(bits * cinSubtile / 8))
10093

10194
# Prepare for packing
10295
# (-1, Weight Bandwidth Bytes, 8)

0 commit comments

Comments
 (0)