Skip to content

Commit 9b9b737

Browse files
committed
add test
1 parent fd9449b commit 9b9b737

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

model_compression_toolkit/core/pytorch/reader/graph_builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def nodes_builder(model: GraphModule,
221221
elif hasattr(torch.Tensor, node.target):
222222
node_type = getattr(torch.Tensor, node.target)
223223
if node_type==torch._C._TensorBase.to:
224-
Logger.critical(f"The call method '{node.target}' is not supported. Please consider moving \"torch.Tensor.to\" operations to init code.") # pragma: no cover
224+
Logger.critical(f"The call method \"to\" is not supported. Please consider moving \"torch.Tensor.to\" operations to init code.") # pragma: no cover
225225
else:
226226
Logger.critical(f"The call method '{node.target}' in {node} is not supported.") # pragma: no cover
227227

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import torch
16+
from torch import nn
17+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
18+
AttachTpcToPytorch
19+
20+
import pytest
21+
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
22+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
23+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
24+
25+
26+
def data_gen():
27+
yield [torch.rand(1, 10, 28, 32)]
28+
29+
30+
class Model(nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.conv1 = nn.Conv2d(10, 20, kernel_size=(5, 4))
34+
self.conv2 = nn.Conv2d(20, 15, kernel_size=(4, 6), groups=5)
35+
36+
def forward(self, x):
37+
x = self.conv1(x)
38+
x = self.conv2(x).to(x.device)
39+
return x
40+
41+
42+
def test_assert_to_operation(minimal_tpc):
43+
Model()(next(data_gen())[0])
44+
45+
fw_impl = PytorchImplementation()
46+
fw_info = DEFAULT_PYTORCH_INFO
47+
model = Model()
48+
49+
with pytest.raises(Exception, match=f"The call method \"to\" is not supported. Please consider moving \"torch.Tensor.to\" operations to init code."):
50+
_ = read_model_to_graph(model,
51+
data_gen,
52+
fqc=AttachTpcToPytorch().attach(minimal_tpc),
53+
fw_info=fw_info,
54+
fw_impl=fw_impl)
55+
56+
57+
58+
59+

0 commit comments

Comments
 (0)