Skip to content

Commit 8cc249e

Browse files
committed
make data_feeder support dynamic shape
1 parent a29cb4b commit 8cc249e

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

python/paddle/fluid/data_feeder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def __init__(self, place, lod_level, shape, dtype):
2929
self.place = place
3030
self.lod_level = lod_level
3131
self.shape = shape
32+
self.dynamic_shape = False
33+
negtive_count = 0
34+
for s in self.shape:
35+
if s < 0:
36+
negtive_count += 1
37+
if negtive_count > 1:
38+
self.shape = None
39+
break
3240
if dtype == core.VarDesc.VarType.FP32:
3341
self.dtype = 'float32'
3442
elif dtype == core.VarDesc.VarType.INT64:
@@ -61,7 +69,9 @@ def _feed_impl_(self, data, lod, lod_level):
6169
self._feed_impl_(each_data, lod[1:], lod_level - 1)
6270

6371
def done(self):
64-
arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
72+
arr = numpy.array(self.data, dtype=self.dtype)
73+
if self.shape:
74+
arr = arr.reshape(self.shape)
6575
t = core.LoDTensor()
6676
t.set(arr, self.place)
6777
if self.lod_level > 0:

0 commit comments

Comments
 (0)