|
28 | 28 |
|
29 | 29 | import numpy as np |
30 | 30 | import os |
| 31 | +import warnings |
31 | 32 | from math import ceil, log2 |
32 | 33 | from qonnx.core.datatype import DataType |
33 | 34 |
|
@@ -87,31 +88,6 @@ def defines(self, var): |
87 | 88 | my_defines.append("#define EmbeddingType %s" % emb_hls_type) |
88 | 89 | self.code_gen_dict["$DEFINES$"] = my_defines |
89 | 90 |
|
90 | | - def read_npy_data(self): |
91 | | - code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") |
92 | | - dtype = self.get_input_datatype() |
93 | | - if dtype == DataType["BIPOLAR"]: |
94 | | - # use binary for bipolar storage |
95 | | - dtype = DataType["BINARY"] |
96 | | - elem_bits = dtype.bitwidth() |
97 | | - packed_bits = self.get_instream_width() |
98 | | - packed_hls_type = "ap_uint<%d>" % packed_bits |
99 | | - elem_hls_type = dtype.get_hls_datatype_str() |
100 | | - npy_type = "int64_t" |
101 | | - npy_in = "%s/input_0.npy" % code_gen_dir |
102 | | - self.code_gen_dict["$READNPYDATA$"] = [] |
103 | | - self.code_gen_dict["$READNPYDATA$"].append( |
104 | | - 'npy2apintstream<%s, %s, %d, %s>("%s", in0_%s);' |
105 | | - % ( |
106 | | - packed_hls_type, |
107 | | - elem_hls_type, |
108 | | - elem_bits, |
109 | | - npy_type, |
110 | | - npy_in, |
111 | | - self.hls_sname(), |
112 | | - ) |
113 | | - ) |
114 | | - |
115 | 91 | def dataoutstrm(self): |
116 | 92 | code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") |
117 | 93 | dtype = self.get_output_datatype() |
@@ -273,7 +249,18 @@ def execute_node(self, context, graph): |
273 | 249 | ) |
274 | 250 |
|
275 | 251 | inp = context[node.input[0]] |
276 | | - assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray" |
| 252 | + |
| 253 | + # Make sure the input has the right container datatype |
| 254 | + if inp.dtype is not np.float32: |
| 255 | + # Issue a warning to make the user aware of this type-cast |
| 256 | + warnings.warn( |
| 257 | + f"{node.name}: Changing input container datatype from " |
| 258 | + f"{inp.dtype} to {np.float32}" |
| 259 | + ) |
| 260 | + # Convert the input to floating point representation as the |
| 261 | + # container datatype |
| 262 | + inp = inp.astype(np.float32) |
| 263 | + |
277 | 264 | assert inp.shape == exp_ishape, """Input shape doesn't match expected shape.""" |
278 | 265 | export_idt = self.get_input_datatype() |
279 | 266 | odt = self.get_output_datatype() |
|
0 commit comments