Skip to content

Commit 76eede6

Browse files
authored
Merge pull request Xilinx#1267 from iksnagreb/fix/lookup
[Lookup] Relax input datatype constraints
2 parents a471350 + 32fbff0 commit 76eede6

File tree

2 files changed

+25
-30
lines changed

2 files changed

+25
-30
lines changed

src/finn/custom_op/fpgadataflow/hls/lookup_hls.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import numpy as np
3030
import os
31+
import warnings
3132
from math import ceil, log2
3233
from qonnx.core.datatype import DataType
3334

@@ -87,31 +88,6 @@ def defines(self, var):
8788
my_defines.append("#define EmbeddingType %s" % emb_hls_type)
8889
self.code_gen_dict["$DEFINES$"] = my_defines
8990

90-
def read_npy_data(self):
91-
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
92-
dtype = self.get_input_datatype()
93-
if dtype == DataType["BIPOLAR"]:
94-
# use binary for bipolar storage
95-
dtype = DataType["BINARY"]
96-
elem_bits = dtype.bitwidth()
97-
packed_bits = self.get_instream_width()
98-
packed_hls_type = "ap_uint<%d>" % packed_bits
99-
elem_hls_type = dtype.get_hls_datatype_str()
100-
npy_type = "int64_t"
101-
npy_in = "%s/input_0.npy" % code_gen_dir
102-
self.code_gen_dict["$READNPYDATA$"] = []
103-
self.code_gen_dict["$READNPYDATA$"].append(
104-
'npy2apintstream<%s, %s, %d, %s>("%s", in0_%s);'
105-
% (
106-
packed_hls_type,
107-
elem_hls_type,
108-
elem_bits,
109-
npy_type,
110-
npy_in,
111-
self.hls_sname(),
112-
)
113-
)
114-
11591
def dataoutstrm(self):
11692
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
11793
dtype = self.get_output_datatype()
@@ -273,7 +249,18 @@ def execute_node(self, context, graph):
273249
)
274250

275251
inp = context[node.input[0]]
276-
assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"
252+
253+
# Make sure the input has the right container datatype
254+
if inp.dtype is not np.float32:
255+
# Issue a warning to make the user aware of this type-cast
256+
warnings.warn(
257+
f"{node.name}: Changing input container datatype from "
258+
f"{inp.dtype} to {np.float32}"
259+
)
260+
# Convert the input to floating point representation as the
261+
# container datatype
262+
inp = inp.astype(np.float32)
263+
277264
assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
278265
export_idt = self.get_input_datatype()
279266
odt = self.get_output_datatype()

src/finn/custom_op/fpgadataflow/rtl/streamingfifo_rtl.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,18 @@ def execute_node(self, context, graph):
133133
elif mode == "rtlsim":
134134
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
135135
# create a npy file for the input of the node
136-
assert (
137-
str(inp.dtype) == "float32"
138-
), """Input datatype is
139-
not float32 as expected."""
136+
137+
# Make sure the input has the right container datatype
138+
if inp.dtype is not np.float32:
139+
# Issue a warning to make the user aware of this type-cast
140+
warnings.warn(
141+
f"{node.name}: Changing input container datatype from "
142+
f"{inp.dtype} to {np.float32}"
143+
)
144+
# Convert the input to floating point representation as the
145+
# container datatype
146+
inp = inp.astype(np.float32)
147+
140148
expected_inp_shape = self.get_folded_input_shape()
141149
reshaped_input = inp.reshape(expected_inp_shape)
142150
if DataType[self.get_nodeattr("dataType")] == DataType["BIPOLAR"]:

0 commit comments

Comments
 (0)