Skip to content

Commit c397e13

Browse files
authored
Merge pull request #1121 from reyoung/feature/fix_ndarray_dtypes
Fix bug in DenseScanner of DataProviderConverter.
2 parents 0955977 + 2e47c9d commit c397e13

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

paddle/py_paddle/dataprovider_converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def finish_scan(self, argument):
3434

3535

3636
class DenseScanner(IScanner):
37+
"""
38+
:type __mat__: numpy.ndarray
39+
"""
40+
3741
def __init__(self, input_type, pos):
3842
IScanner.__init__(self, input_type, pos)
3943
self.__mat__ = None
@@ -47,6 +51,8 @@ def scan(self, dat):
4751
def finish_scan(self, argument):
4852
assert isinstance(argument, swig_paddle.Arguments)
4953
assert isinstance(self.input_type, dp2.InputType)
54+
if self.__mat__.dtype != numpy.float32:
55+
self.__mat__ = self.__mat__.astype(numpy.float32)
5056
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
5157
argument.setSlotValue(self.pos, m)
5258

0 commit comments

Comments
 (0)