Skip to content

Commit 6c0b383

Browse files
authored
Add VarType::STEP_SCOPES for RNN (#5056)
1 parent ee998a9 commit 6c0b383

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

paddle/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ message VarDesc {
115115
SELECTED_ROWS = 2;
116116
FEED_MINIBATCH = 3;
117117
FETCH_LIST = 4;
118+
STEP_SCOPES = 5;
118119
}
119120
required string name = 1;
120121
required VarType type = 2;

paddle/pybind/protobuf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ void BindVarDsec(py::module &m) {
224224
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
225225
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
226226
.value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
227-
.value("FETCH_LIST", VarDesc::FETCH_LIST);
227+
.value("FETCH_LIST", VarDesc::FETCH_LIST)
228+
.value("STEP_SCOPES", VarDesc::STEP_SCOPES);
228229
}
229230

230231
void BindOpDesc(py::module &m) {

python/paddle/v2/framework/tests/test_variable.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from paddle.v2.framework.framework import Variable, g_program
2+
from paddle.v2.framework.framework import Variable, g_program, Program
33
import paddle.v2.framework.core as core
44
import numpy as np
55

@@ -36,6 +36,13 @@ def test_var(self):
3636
self.assertRaises(ValueError,
3737
lambda: b.create_var(name="fc.w", shape=(24, 100)))
3838

39+
def test_step_scopes(self):
40+
prog = Program()
41+
b = prog.current_block()
42+
var = b.create_var(
43+
name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES)
44+
self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type)
45+
3946

4047
if __name__ == '__main__':
4148
unittest.main()

0 commit comments

Comments
 (0)