Skip to content

Commit 9991a60

Browse files
Add e2e test
1 parent f2c2e9d commit 9991a60

File tree

3 files changed

+107
-8
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import numpy as np
16+
import model_compression_toolkit as mct
17+
import torch
18+
import torch.nn as nn
19+
import pytest
20+
21+
22+
class ScalarModel(nn.Module):
    """Toy model whose graph contains a scalar-valued constant op.

    A named torch operator is applied to a 0-dim (scalar) parameter; the
    result is a constant with respect to the input and is added to the
    network input. Used to exercise quantization of scalar constants.
    """

    # Dispatch table: op name -> callable applied to the scalar parameter.
    # Replaces the original 24-branch if/elif chain; same ops, same arguments.
    _OPS = {
        'add': lambda s: torch.add(s, 1),
        'relu6': torch.nn.functional.relu6,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
        'eq': lambda s: torch.eq(s, 1),
        'leaky_relu': torch.nn.functional.leaky_relu,
        'mul': lambda s: torch.mul(s, 1),
        'sub': lambda s: torch.sub(s, 1),
        'div': lambda s: torch.div(s, 1),
        # dim=0 is what the deprecated implicit-dim softmax resolves to for a
        # 0-dim input; made explicit to silence the deprecation warning.
        'softmax': lambda s: torch.nn.functional.softmax(s, dim=0),
        'tanh': torch.tanh,
        'negative': torch.negative,
        'abs': torch.abs,
        'sqrt': torch.sqrt,
        'sum': torch.sum,
        'rsqrt': torch.rsqrt,
        'silu': torch.nn.functional.silu,
        'hardswish': torch.nn.functional.hardswish,
        'hardsigmoid': torch.nn.functional.hardsigmoid,
        'pow': lambda s: torch.pow(s, 1),
        'gelu': torch.nn.functional.gelu,
        'cos': torch.cos,
        'sin': torch.sin,
        'exp': torch.exp,
    }

    def __init__(self, name):
        """Store the op name and create the scalar parameter.

        Args:
            name: Key into ``_OPS`` selecting the operator under test.

        Raises:
            ValueError: If ``name`` is not a supported op (previously this
                surfaced as a confusing ``NameError`` inside ``forward``).
        """
        super().__init__()
        if name not in self._OPS:
            raise ValueError(f"Unsupported op name: {name!r}. "
                             f"Expected one of {sorted(self._OPS)}")
        self.name = name
        self.scalar = nn.Parameter(2.0 * torch.ones([]))  # 0-dim (scalar) parameter

    def forward(self, x):
        """Apply the selected op to the scalar and add the result to ``x``."""
        const = self._OPS[self.name](self.scalar)
        return x + const
82+
83+
def representative_data_gen():
    """Yield one batch of random calibration data shaped (1, 3, 8, 8)."""
    batch = np.random.random((1, 3, 8, 8))
    yield [batch]
85+
86+
# Every scalar op exercised by the e2e quantization test.
_SCALAR_OP_NAMES = [
    'add', 'relu6', 'relu', 'sigmoid', 'eq', 'leaky_relu', 'mul', 'sub', 'div', 'softmax',
    'tanh', 'negative', 'abs', 'sqrt', 'sum', 'rsqrt', 'silu', 'hardswish', 'hardsigmoid',
    'pow', 'gelu', 'cos', 'sin', 'exp'
]


@pytest.mark.parametrize("layer_name", _SCALAR_OP_NAMES)
def test_scalar_layer(layer_name):
    """End-to-end PTQ of a model whose graph holds a scalar constant op.

    Passes as long as post-training quantization completes without raising.
    """
    float_model = ScalarModel(name=layer_name)
    tpc = mct.get_target_platform_capabilities("6.0")
    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
        float_model,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=tpc)

tests_pytest/pytorch_tests/integration_tests/core/test_model_collector_for_scalar.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ def fw_info_mock():
5353
class TestModelCollectorInit:
5454

5555
def test_init(self, fw_impl_mock, fw_info_mock):
56-
node0 = build_node('node0', output_shape=(1, 3, 2, 2)) # 4D tensor
57-
node1 = build_node('node1', output_shape=(3, 2)) # 2D tensor
58-
node2 = build_node('node2', output_shape=(4,)) # 1D tensor
59-
node3 = build_node('node3', output_shape=()) # Scalar
56+
node0 = build_node('node0', output_shape=[[1, 3, 2, 2]]) # 4D tensor
57+
node1 = build_node('node1', output_shape=[[3, 2]]) # 2D tensor
58+
node2 = build_node('node2', output_shape=[[4]]) # 1D tensor
59+
node3 = build_node('node3', output_shape=[[]]) # Scalar
6060

6161
mock_nodes_list = [node0, node1, node2, node3]
6262
for node in mock_nodes_list:
@@ -88,10 +88,10 @@ def test_init(self, fw_impl_mock, fw_info_mock):
8888
class TestModelCollectorInfer:
8989

9090
def test_infer(self, fw_impl_mock, fw_info_mock):
91-
node0 = build_node('node0', output_shape=(1, 3, 2, 2)) # 4D tensor
92-
node1 = build_node('node1', output_shape=(3, 2)) # 2D tensor
93-
node2 = build_node('node2', output_shape=(4,)) # 1D tensor
94-
node3 = build_node('node3', output_shape=()) # scalar
91+
node0 = build_node('node0', output_shape=[[1, 3, 2, 2]]) # 4D tensor
92+
node1 = build_node('node1', output_shape=[[3, 2]]) # 2D tensor
93+
node2 = build_node('node2', output_shape=[[4]]) # 1D tensor
94+
node3 = build_node('node3', output_shape=[[]]) # scalar
9595

9696
mock_nodes_list = [node0, node1, node2, node3]
9797
for node in mock_nodes_list:

tests_pytest/pytorch_tests/unit_tests/core/test_create_stats_collector_for_node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def node_mock():
4444

4545

4646
class TestCreateStatsCollectorForNode:
47+
4748
def test_create_stats_collector_for_node_conv(self, node_mock, fw_info_mock):
4849
node_mock.type = Conv2D
4950
node_mock.get_output_shapes_list.return_value = [[1, 3, 32, 32]]

0 commit comments

Comments
 (0)