Skip to content

Commit 76a3678

Browse files
authored
fix lod_reset check dtype (#24227)
1 parent c36c67f commit 76a3678

File tree

3 files changed

+2
-20
lines changed

3 files changed

+2
-20
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7660,8 +7660,7 @@ def lod_reset(x, y=None, target_lod=None):
76607660
out = helper.create_variable_for_type_inference(dtype=x.dtype)
76617661
if y is not None:
76627662
check_type(y, 'y', (Variable), 'lod_reset')
7663-
if y.lod_level == 0:
7664-
check_variable_and_dtype(y, 'y', ['int32'], 'lod_reset')
7663+
#TODO: check y.lod_level = 0 dtype
76657664
helper.append_op(
76667665
type="lod_reset", inputs={'X': x,
76677666
'Y': y}, outputs={'Out': out})
@@ -7732,8 +7731,7 @@ def lod_append(x, level):
77327731

77337732
if isinstance(level, Variable):
77347733
inputs['Y'] = level
7735-
if level.lod_level == 0:
7736-
check_variable_and_dtype(level, 'level', ['int32'], 'lod_append')
7734+
#TODO: check y.lod_level = 0 dtype
77377735
else:
77387736
attrs['target_lod'] = level
77397737
helper.append_op(

python/paddle/fluid/tests/unittests/test_lod_append_op.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,6 @@ def test_error(self):
6767
name='level3' + dtype, shape=[4], dtype='int32', lod_level=2)
6868
self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3)
6969

70-
# Input(level) dtype must be int32 when lod_level=0
71-
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
72-
x4 = fluid.layers.data(
73-
name='x4' + dtype, shape=[4], dtype='float32')
74-
level4 = fluid.layers.data(
75-
name='level4_' + dtype, shape=[4], dtype=dtype, lod_level=0)
76-
self.assertRaises(TypeError, fluid.layers.lod_append, x4, level4)
77-
7870

7971
if __name__ == "__main__":
8072
unittest.main()

python/paddle/fluid/tests/unittests/test_lod_reset_op.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,6 @@ def test_errors(self):
150150
name='y2' + dtype, shape=[4], dtype='int32', lod_level=2)
151151
self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2)
152152

153-
# Input(y) dtype must be int32 when lod_level=0
154-
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
155-
x3 = fluid.layers.data(
156-
name='x3' + dtype, shape=[4], dtype='float32')
157-
y3 = fluid.layers.data(
158-
name='y3' + dtype, shape=[4], dtype=dtype, lod_level=0)
159-
self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3)
160-
161153

162154
if __name__ == '__main__':
163155
unittest.main()

0 commit comments

Comments
 (0)