Skip to content

Commit 13dd71f

Browse files
Clean up synthetic weights/inputs generation integration in pulp-nnx
1 parent 973cd49 commit 13dd71f

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

test/NnxTestClasses.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class NnxTestConf(BaseModel):
4848
has_norm_quant: bool
4949
has_bias: bool
5050
has_relu: bool
51+
synthetic_weights: bool
52+
synthetic_inputs: bool
5153

5254
@model_validator(mode="after") # type: ignore
5355
def check_valid_depthwise_channels(self) -> NnxTestConf:
@@ -116,6 +118,8 @@ def __init__(
116118
scale: Optional[torch.Tensor] = None,
117119
bias: Optional[torch.Tensor] = None,
118120
global_shift: Optional[torch.Tensor] = torch.Tensor([0]),
121+
synthetic_weights: Optional[bool] = False,
122+
synthetic_inputs: Optional[bool] = False,
119123
) -> None:
120124
self.conf = conf
121125
self.input = input
@@ -124,6 +128,8 @@ def __init__(
124128
self.scale = scale
125129
self.bias = bias
126130
self.global_shift = global_shift
131+
self.synthetic_weights = synthetic_weights
132+
self.synthetic_inputs = synthetic_inputs
127133

128134
def is_valid(self) -> bool:
129135
return all(
@@ -243,20 +249,30 @@ def from_conf(
243249
bias_shape = (1, conf.out_channel, 1, 1)
244250

245251
if input is None:
246-
input = NnxTestGenerator._random_data(
247-
_type=conf.in_type,
248-
shape=input_shape,
249-
)
252+
if conf.synthetic_inputs:
253+
inputs = torch.zeros((1, conf.in_channel, conf.in_height, conf.in_width), dtype=torch.int64)
254+
for i in range(conf.in_channel):
255+
inputs[:, i,0,0] = i
256+
else:
257+
input = NnxTestGenerator._random_data(
258+
_type=conf.in_type,
259+
shape=input_shape,
260+
)
250261

251262
if weight is None:
252-
weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN
253-
weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1)
254-
weight = NnxTestGenerator._random_data_normal(
255-
mean = weight_mean,
256-
std = weight_std,
257-
_type=conf.weight_type,
258-
shape=weight_shape,
259-
)
263+
if conf.synthetic_weights:
264+
weight = torch.zeros((conf.out_channel, 1 if conf.depthwise else conf.in_channel, conf.kernel_shape.height, conf.kernel_shape.width), dtype=torch.int64)
265+
for i in range(0, min(weight.shape[0], weight.shape[1])):
266+
weight[i,i,0,0] = 1
267+
else:
268+
weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN
269+
weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1)
270+
weight = NnxTestGenerator._random_data_normal(
271+
mean = weight_mean,
272+
std = weight_std,
273+
_type=conf.weight_type,
274+
shape=weight_shape,
275+
)
260276

261277
if conf.has_norm_quant:
262278
if scale is None:
@@ -306,6 +322,8 @@ def from_conf(
306322
scale=scale,
307323
bias=bias,
308324
global_shift=global_shift,
325+
synthetic_inputs=conf.synthetic_inputs,
326+
synthetic_weights=conf.synthetic_weights,
309327
)
310328

311329
@staticmethod
@@ -361,7 +379,10 @@ def generate(self, test_name: str, test: NnxTest):
361379
weight_type = test.conf.weight_type
362380
weight_bits = weight_type._bits
363381
assert weight_bits > 1 and weight_bits <= 8
364-
weight_offset = -(2 ** (weight_bits - 1))
382+
if test.synthetic_weights:
383+
weight_offset = 0
384+
else:
385+
weight_offset = -(2 ** (weight_bits - 1))
365386
weight_out_ch, weight_in_ch, weight_ks_h, weight_ks_w = test.weight.shape
366387
weight_data: np.ndarray = test.weight.numpy() - weight_offset
367388
weight_init = self.weightEncode(

test/testgen.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,7 @@ def test_gen(
8686
exit(-1)
8787

8888
test_conf = nnxTestConfCls.model_validate(test_conf_dict)
89-
if test_conf_dict['synthetic_weights']:
90-
import torch
91-
weight = torch.zeros((test_conf.out_channel, 1 if test_conf.depthwise else test_conf.in_channel, test_conf.kernel_shape.height, test_conf.kernel_shape.width), dtype=torch.int64)
92-
for i in range(0, min(weight.shape[0], weight.shape[1])):
93-
weight[i,i,0,0] = 1
94-
else:
95-
weight = None
96-
if test_conf_dict['synthetic_inputs']:
97-
import torch
98-
inputs = torch.zeros((1, test_conf.in_channel, test_conf.in_height, test_conf.in_width), dtype=torch.int64)
99-
for i in range(test_conf.in_channel):
100-
inputs[:, i,0,0] = i
101-
else:
102-
inputs = None
103-
test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors, weight=weight, input=inputs)
89+
test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors)
10490
if not args.skip_save:
10591
test.save(args.test_dir)
10692
if args.headers:

0 commit comments

Comments
 (0)