Skip to content

Commit 45cfa03

Browse files
committed
Make gen_finn_dt_tensor consider the numpy type for INT and FIXED types
1 parent cadd6b2 commit 45cfa03

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/qonnx/util/basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,12 @@ def gen_finn_dt_tensor(finn_dt, tensor_shape):
228228
elif finn_dt == DataType["BINARY"]:
229229
tensor_values = np.random.randint(2, size=tensor_shape)
230230
elif "INT" in finn_dt.name or finn_dt == DataType["TERNARY"]:
231-
tensor_values = np.random.randint(finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape)
231+
tensor_values = np.random.randint(
232+
finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape, dtype=finn_dt.to_numpy_dt()
233+
)
232234
elif "FIXED" in finn_dt.name:
233235
int_dt = DataType["INT" + str(finn_dt.bitwidth())]
234-
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape)
236+
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape, dtype=int_dt.to_numpy_dt())
235237
tensor_values = tensor_values * finn_dt.scale_factor()
236238
elif finn_dt == DataType["FLOAT32"]:
237239
tensor_values = np.random.randn(*tensor_shape)

0 commit comments

Comments
 (0)