Commit b4d7cd4 (parent 1250806)

Fix 32bit scale (#5)

* Fix functional model saturating casts
* Add more intermediate result prints
* Fix global_shift calculation
* Add tests with 32bit scale

File tree: 8 files changed, +115 −10 lines

CHANGELOG.md (2 additions, 0 deletions)

```diff
@@ -8,12 +8,14 @@
 - Support for kernels without normalization and quantization for NE16
 - isort check
 - publication citation
+- support 32bit scale

 ### Changed

 - `ne16_task_init` got split into smaller parts: `ne16_task_init`, `ne16_task_set_op_to_conv`, `ne16_task_set_weight_offset`, `ne16_task_set_bits`, `ne16_task_set_norm_quant`
 - strides in `ne16_task_set_strides`, `ne16_task_set_dims`, and `ne16_task_set_ptrs` are now strides between consecutive elements in that dimension
 - `ne16_task_queue_size` is now `NE16_TASK_QUEUE_SIZE`
+- `ne16_task_set_ptrs` split into `ne16_task_set_ptrs_conv` and `ne16_task_set_ptrs_norm_quant`

 ### Removed
```
ne16/README.md (1 addition, 1 deletion)

```diff
@@ -28,7 +28,7 @@
 - [ ] Scale type
   - [x] uint8
   - [ ] uint16
-  - [ ] uint32
+  - [x] uint32
 - [x] Bias type
   - [x] int32
 - [ ] Weight type
```

neureka/README.md (2 additions, 3 deletions)

```diff
@@ -16,17 +16,16 @@ Github repo [link](https://github.com/siracusa-soc/ne).
 - [x] Bias (w/ and w/o)
 - [ ] Per-channel shift
 - [x] Per-layer shift
-- [ ] Rounding
 - [x] Input type
   - [x] uint8
   - [x] int8
 - [x] Output type
   - [x] int8
   - [x] uint8 (only w/ Relu)
   - [x] int32
-- [ ] Scale type
+- [x] Scale type
   - [x] uint8
-  - [ ] uint32
+  - [x] uint32
 - [x] Bias type
   - [x] int32
 - [ ] Weight type
```

test/NeuralEngineFunctionalModel.py (14 additions, 3 deletions)

```diff
@@ -28,24 +28,34 @@ def _norm_quant(
         bias_type: Optional[IntegerType],
         has_bias: bool,
         has_relu: bool,
+        verbose: bool,
     ) -> torch.Tensor:
         # Scale accumulators are in 48bit, so keeping the data in 64bit
         tensor = tensor * scale
         assert tensor.dtype == torch.int64

+        if verbose:
+            print("INTERMEDIATE RESULTS (after scale):")
+            print(tensor)
+
         if has_bias:
             assert bias is not None
             assert bias_type is not None
-            # Saturating cast to int32
+
             tensor = NeuralEngineFunctionalModel._cast(
-                tensor, bias_type, saturate=True
+                tensor, bias_type, saturate=False
             ).type(torch.int32)

             tensor = tensor + bias
+
             tensor = NeuralEngineFunctionalModel._cast(
-                tensor, bias_type, saturate=False
+                tensor, bias_type, saturate=True
             ).type(torch.int32)

+            if verbose:
+                print("INTERMEDIATE RESULTS (after bias):")
+                print(tensor)
+
         if has_relu:
             tensor = F.relu(tensor)

@@ -118,6 +128,7 @@ def convolution(
             bias_type,
             has_bias,
             has_relu,
+            verbose,
         )

         return output
```
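The core of the fix above swaps the `saturate` flags around the bias addition: the scaled accumulator is now wrap-cast (non-saturating) to the bias type before the bias is added, and only the post-bias sum is saturated. A minimal pure-Python sketch of why the ordering matters; the real model casts torch tensors via `NeuralEngineFunctionalModel._cast`, so `cast_i32` here is a hypothetical scalar stand-in, not the actual implementation:

```python
INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1

def cast_i32(x: int, saturate: bool) -> int:
    """Cast to int32 by clamping (saturate=True) or two's-complement
    wrapping modulo 2**32 (saturate=False)."""
    if saturate:
        return max(INT32_MIN, min(INT32_MAX, x))
    return ((x + (1 << 31)) % (1 << 32)) - (1 << 31)

# A 48-bit scaled accumulator can exceed the int32 range; the two
# cast orderings then disagree on the final value.
acc = INT32_MAX + 10   # slightly above int32 range
bias = -100

old_order = cast_i32(cast_i32(acc, saturate=True) + bias, saturate=False)
new_order = cast_i32(cast_i32(acc, saturate=False) + bias, saturate=True)
print(old_order, new_order)  # the results differ for out-of-range acc
```

For in-range accumulators both orderings agree, which is why the discrepancy only surfaced with the wider 32-bit scales exercised by the new tests.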

test/NnxTestClasses.py (9 additions, 3 deletions)

```diff
@@ -254,16 +254,22 @@ def from_conf(
         ).type(torch.int32)
         if global_shift is None:
             global_shift = torch.Tensor([0]).type(torch.int32)
+            conv_kwargs = {
+                **conf.__dict__,
+                "out_type": NeuralEngineFunctionalModel.ACCUMULATOR_TYPE,
+            }
             output = NeuralEngineFunctionalModel().convolution(
                 input,
                 weight,
                 scale,
                 bias,
                 global_shift,
-                verbose=verbose,
-                **conf.__dict__,
+                verbose=False,
+                **conv_kwargs,
+            )
+            global_shift = NnxTestGenerator._calculate_global_shift(
+                output, conf.out_type
             )
-            NnxTestGenerator._calculate_global_shift(output, conf.out_type)

         output = NeuralEngineFunctionalModel().convolution(
             input, weight, scale, bias, global_shift, verbose=verbose, **conf.__dict__
```
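This diff fixes two bugs in the calibration pass: the first convolution now runs with the accumulator type instead of the final `out_type`, and the result of `_calculate_global_shift` is actually assigned to `global_shift` (previously it was computed and discarded). The diff does not show `_calculate_global_shift` itself; the sketch below is a guess at its contract, choosing the smallest right-shift that fits the worst-case accumulator magnitude into the signed output range, and is not the actual implementation:

```python
def calculate_global_shift(acc_max_abs: int, out_bits: int) -> int:
    """Hypothetical stand-in for NnxTestGenerator._calculate_global_shift:
    smallest right-shift making the largest accumulator magnitude
    representable in a signed out_bits-wide type (assumption)."""
    if acc_max_abs == 0:
        return 0
    bits_needed = acc_max_abs.bit_length() + 1  # +1 for the sign bit
    return max(0, bits_needed - out_bits)

shift = calculate_global_shift(1000, 8)  # 1000 needs 11 signed bits
print(shift, 1000 >> shift)              # shifted value fits in int8
```

Running the calibration pass against the accumulator type matters because clipping the first pass to `out_type` would hide exactly the overflow that `global_shift` is supposed to compensate for.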

test/tests/test_116/conf.json (new file, 29 additions)

```diff
@@ -0,0 +1,29 @@
+{
+  "in_height": 3,
+  "in_width": 3,
+  "in_channel": 2,
+  "out_channel": 2,
+  "padding": {
+    "top": 0,
+    "bottom": 0,
+    "left": 0,
+    "right": 0
+  },
+  "kernel_shape": {
+    "height": 1,
+    "width": 1
+  },
+  "depthwise": false,
+  "stride": {
+    "height": 1,
+    "width": 1
+  },
+  "in_type": "int8",
+  "out_type": "int8",
+  "weight_type": "int8",
+  "scale_type": "uint32",
+  "bias_type": "int32",
+  "has_norm_quant": true,
+  "has_bias": true,
+  "has_relu": false
+}
```

test/tests/test_117/conf.json (new file, 29 additions)

```diff
@@ -0,0 +1,29 @@
+{
+  "in_height": 10,
+  "in_width": 10,
+  "in_channel": 10,
+  "out_channel": 10,
+  "padding": {
+    "top": 0,
+    "bottom": 0,
+    "left": 0,
+    "right": 0
+  },
+  "kernel_shape": {
+    "height": 1,
+    "width": 1
+  },
+  "depthwise": false,
+  "stride": {
+    "height": 1,
+    "width": 1
+  },
+  "in_type": "uint8",
+  "out_type": "int8",
+  "weight_type": "int8",
+  "scale_type": "uint32",
+  "bias_type": "int32",
+  "has_norm_quant": true,
+  "has_bias": true,
+  "has_relu": false
+}
```

test/tests/test_118/conf.json (new file, 29 additions)

```diff
@@ -0,0 +1,29 @@
+{
+  "in_height": 10,
+  "in_width": 10,
+  "in_channel": 128,
+  "out_channel": 128,
+  "padding": {
+    "top": 0,
+    "bottom": 0,
+    "left": 0,
+    "right": 0
+  },
+  "kernel_shape": {
+    "height": 1,
+    "width": 1
+  },
+  "depthwise": false,
+  "stride": {
+    "height": 1,
+    "width": 1
+  },
+  "in_type": "uint8",
+  "out_type": "int8",
+  "weight_type": "int8",
+  "scale_type": "uint32",
+  "bias_type": "int32",
+  "has_norm_quant": true,
+  "has_bias": true,
+  "has_relu": false
+}
```
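The three new test configurations share one structure and differ only in tensor shape and input type; all three pin `scale_type` to `uint32`, the path this commit enables. A small sketch of how such a conf might be consumed by test tooling; the field names come from the files above, but the loader itself is hypothetical:

```python
import json

# Abbreviated copy of test_116's conf (shape fields trimmed for brevity).
conf_116 = json.loads("""
{
  "in_height": 3, "in_width": 3, "in_channel": 2, "out_channel": 2,
  "in_type": "int8", "out_type": "int8", "weight_type": "int8",
  "scale_type": "uint32", "bias_type": "int32",
  "has_norm_quant": true, "has_bias": true, "has_relu": false
}
""")

# Bit width of the scale, parsed from the type string ("uint32" -> 32).
scale_bits = int(conf_116["scale_type"].removeprefix("uint"))
print(conf_116["scale_type"], scale_bits)
```

test_117 and test_118 then scale the same configuration up (10×10×10→10 and 10×10×128→128, both uint8 input) to exercise the 32-bit scale path at larger accumulator magnitudes.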

0 commit comments