Skip to content

Commit d10c510

Browse files
Missing corrections
1 parent 70731c9 commit d10c510

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

tests/pytorch_tests/test_activation_quantizer_holder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def test_activation_quantization_holder_save_and_load(self):
7272
quantizer = quantizer_class(**quantizer_args)
7373
model = PytorchActivationQuantizationHolder(quantizer)
7474

75-
# Initialize a random input to quantize between -50 to 50.
76-
x = torch.from_numpy(np.random.rand(1, 3, 50, 50). astype(np.float32) * 100 - 50, )
75+
# Initialize a random input to quantize between -50 to 50. Input includes positive and negative values.
76+
x = torch.rand(1, 3, 50, 50) * 50
77+
signs = torch.from_numpy(np.where(np.indices((1, 3, 50, 50)).sum(axis=0) % 2 == 0, 1, -1).astype(np.int8))
78+
x = x * signs
7779
exp_output_tensor = model(x)
7880

7981
fx_model = symbolic_trace(model)

tests/pytorch_tests/test_fln_activation_quantizer_holder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ def test_fln_activation_quantization_holder_save_and_load(self):
110110
quantizer = quantizer_class(**quantizer_args)
111111
model = PytorchFLNActivationQuantizationHolder(quantizer, quantization_bypass)
112112

113-
# Initialize a random input to quantize between -50 to 50.
114-
x = torch.from_numpy(np.random.rand(1, 3, 50, 50). astype(np.float32) * 100 - 50, )
113+
# Initialize a random input to quantize between -50 to 50. Input includes positive and negative values.
114+
x = torch.rand(1, 3, 50, 50) * 50
115+
signs = torch.from_numpy(np.where(np.indices((1, 3, 50, 50)).sum(axis=0) % 2 == 0, 1, -1).astype(np.int8))
116+
x = x * signs
115117
exp_output_tensor = model(x)
116118

117119
fx_model = symbolic_trace(model)

tests/pytorch_tests/test_preserving_activation_quantizer_holder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ def test_preserving_activation_quantization_holder_save_and_load(self):
110110
quantizer = quantizer_class(**quantizer_args)
111111
model = PytorchPreservingActivationQuantizationHolder(quantizer, quantization_bypass)
112112

113-
# Initialize a random input to quantize between -50 to 50.
114-
x = torch.from_numpy(np.random.rand(1, 3, 50, 50). astype(np.float32) * 100 - 50, )
113+
# Initialize a random input to quantize between -50 to 50. Input includes positive and negative values.
114+
x = torch.rand(1, 3, 50, 50) * 50
115+
signs = torch.from_numpy(np.where(np.indices((1, 3, 50, 50)).sum(axis=0) % 2 == 0, 1, -1).astype(np.int8))
116+
x = x * signs
115117
exp_output_tensor = model(x)
116118

117119
fx_model = symbolic_trace(model)

0 commit comments

Comments
 (0)