Skip to content

Commit 3de81d0

Browse files
committed
[Lookup] Reintroduce prior assertion as warning and fix type comparison
1 parent 6ace464 commit 3de81d0

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/finn/custom_op/fpgadataflow/hls/lookup_hls.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import numpy as np
3030
import os
31+
import warnings
3132
from math import ceil, log2
3233
from qonnx.core.datatype import DataType
3334

@@ -273,7 +274,18 @@ def execute_node(self, context, graph):
273274
)
274275

275276
inp = context[node.input[0]]
276-
# assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"
277+
278+
# Make sure the input has the right container datatype
279+
if inp.dtype is not np.float32:
280+
# Issue a warning to make the user aware of this type-cast
281+
warnings.warn(
282+
f"{node.name}: Changing input container datatype from "
283+
f"{inp.dtype} to {np.float32}"
284+
)
285+
# Convert the input to floating point representation as the
286+
# container datatype
287+
inp = inp.astype(np.float32)
288+
277289
assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
278290
export_idt = self.get_input_datatype()
279291
odt = self.get_output_datatype()

src/finn/custom_op/fpgadataflow/rtl/streamingfifo_rtl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ def execute_node(self, context, graph):
134134
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
135135
# create a npy file for the input of the node
136136

137-
# Make sure the inpout has the right container datatype
138-
if inp.dtype != np.float32:
137+
# Make sure the input has the right container datatype
138+
if inp.dtype is not np.float32:
139139
# Issue a warning to make the user aware of this type-cast
140140
warnings.warn(
141-
f"{node.name}: Changing input datatype from "
141+
f"{node.name}: Changing input container datatype from "
142142
f"{inp.dtype} to {np.float32}"
143143
)
144144
# Convert the input to floating point representation as the

0 commit comments

Comments
 (0)