@@ -48,6 +48,8 @@ class NnxTestConf(BaseModel):
4848 has_norm_quant : bool
4949 has_bias : bool
5050 has_relu : bool
51+ synthetic_weights : bool
52+ synthetic_inputs : bool
5153
5254 @model_validator (mode = "after" ) # type: ignore
5355 def check_valid_depthwise_channels (self ) -> NnxTestConf :
@@ -116,6 +118,8 @@ def __init__(
116118 scale : Optional [torch .Tensor ] = None ,
117119 bias : Optional [torch .Tensor ] = None ,
118120 global_shift : Optional [torch .Tensor ] = torch .Tensor ([0 ]),
121+ synthetic_weights : Optional [bool ] = False ,
122+ synthetic_inputs : Optional [bool ] = False ,
119123 ) -> None :
120124 self .conf = conf
121125 self .input = input
@@ -124,6 +128,8 @@ def __init__(
124128 self .scale = scale
125129 self .bias = bias
126130 self .global_shift = global_shift
131+ self .synthetic_weights = synthetic_weights
132+ self .synthetic_inputs = synthetic_inputs
127133
128134 def is_valid (self ) -> bool :
129135 return all (
@@ -243,20 +249,30 @@ def from_conf(
243249 bias_shape = (1 , conf .out_channel , 1 , 1 )
244250
245251 if input is None :
246- input = NnxTestGenerator ._random_data (
247- _type = conf .in_type ,
248- shape = input_shape ,
249- )
252+ if conf .synthetic_inputs :
253+ inputs = torch .zeros ((1 , conf .in_channel , conf .in_height , conf .in_width ), dtype = torch .int64 )
254+ for i in range (conf .in_channel ):
255+ inputs [:, i ,0 ,0 ] = i
256+ else :
257+ input = NnxTestGenerator ._random_data (
258+ _type = conf .in_type ,
259+ shape = input_shape ,
260+ )
250261
251262 if weight is None :
252- weight_mean = NnxTestGenerator ._DEFAULT_WEIGHT_MEAN
253- weight_std = NnxTestGenerator ._DEFAULT_WEIGHT_STDEV * (1 << (conf .weight_type ._bits - 1 )- 1 )
254- weight = NnxTestGenerator ._random_data_normal (
255- mean = weight_mean ,
256- std = weight_std ,
257- _type = conf .weight_type ,
258- shape = weight_shape ,
259- )
263+ if conf .synthetic_weights :
264+ weight = torch .zeros ((conf .out_channel , 1 if conf .depthwise else conf .in_channel , conf .kernel_shape .height , conf .kernel_shape .width ), dtype = torch .int64 )
265+ for i in range (0 , min (weight .shape [0 ], weight .shape [1 ])):
266+ weight [i ,i ,0 ,0 ] = 1
267+ else :
268+ weight_mean = NnxTestGenerator ._DEFAULT_WEIGHT_MEAN
269+ weight_std = NnxTestGenerator ._DEFAULT_WEIGHT_STDEV * (1 << (conf .weight_type ._bits - 1 )- 1 )
270+ weight = NnxTestGenerator ._random_data_normal (
271+ mean = weight_mean ,
272+ std = weight_std ,
273+ _type = conf .weight_type ,
274+ shape = weight_shape ,
275+ )
260276
261277 if conf .has_norm_quant :
262278 if scale is None :
@@ -306,6 +322,8 @@ def from_conf(
306322 scale = scale ,
307323 bias = bias ,
308324 global_shift = global_shift ,
325+ synthetic_inputs = conf .synthetic_inputs ,
326+ synthetic_weights = conf .synthetic_weights ,
309327 )
310328
311329 @staticmethod
@@ -361,7 +379,10 @@ def generate(self, test_name: str, test: NnxTest):
361379 weight_type = test .conf .weight_type
362380 weight_bits = weight_type ._bits
363381 assert weight_bits > 1 and weight_bits <= 8
364- weight_offset = - (2 ** (weight_bits - 1 ))
382+ if test .synthetic_weights :
383+ weight_offset = 0
384+ else :
385+ weight_offset = - (2 ** (weight_bits - 1 ))
365386 weight_out_ch , weight_in_ch , weight_ks_h , weight_ks_w = test .weight .shape
366387 weight_data : np .ndarray = test .weight .numpy () - weight_offset
367388 weight_init = self .weightEncode (
0 commit comments