File tree Expand file tree Collapse file tree 2 files changed +43
-1
lines changed
python/paddle/fluid/tests/unittests Expand file tree Collapse file tree 2 files changed +43
-1
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
47
47
out_shape[i] = x_dims[i] * expand_times[i];
48
48
}
49
49
50
+ // set the first dim to -1 in compile time
51
+ if (!ctx->IsRuntime ()) {
52
+ out_shape[0 ] = x_dims[0 ];
53
+ }
54
+
50
55
ctx->SetOutputDim (" Out" , framework::make_ddim (out_shape));
51
56
if (out_shape[0 ] == x_dims[0 ]) {
52
57
ctx->ShareLoD (" X" , " Out" );
@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel {
109
114
ctx->Attrs ().Get <std::vector<int >>(" expand_times" );
110
115
auto out_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
111
116
112
- for (size_t i = 0 ; i < expand_times.size (); ++i) {
117
+ size_t start_pos = 0u ;
118
+ if (!ctx->IsRuntime ()) {
119
+ PADDLE_ENFORCE_EQ (
120
+ x_dims[0 ], out_dims[0 ],
121
+ " The first dimension size of Input(Out@GRAD) should be "
122
+ " equal to the crroresponding dimension size of Input(X)" );
123
+ start_pos = 1u ;
124
+ }
125
+
126
+ for (size_t i = start_pos; i < expand_times.size (); ++i) {
113
127
PADDLE_ENFORCE_EQ (x_dims[i] * expand_times[i], out_dims[i],
114
128
" Each dimension size of Input(Out@GRAD) should be "
115
129
" equal to multiplication of crroresponding dimension "
Original file line number Diff line number Diff line change @@ -83,6 +83,34 @@ def test_mul_op(self):
83
83
mul_op_desc .infer_shape (block )
84
84
self .assertEqual (out .shape (), [x_shape [0 ], y_shape [1 ]])
85
85
86
+ def test_expand_op (self ):
87
+ prog = core .ProgramDesc ()
88
+ self .assertIsNotNone (prog )
89
+ block = prog .block (0 )
90
+ self .assertIsNotNone (block )
91
+
92
+ shape = [- 1 , 20 ]
93
+ expand_times = [3 , 1 ]
94
+
95
+ # prepare input/output
96
+ x1 = block .var (six .b ("x" ))
97
+ x1 .set_type (core .VarDesc .VarType .LOD_TENSOR )
98
+ x1 .set_shape (shape )
99
+
100
+ out = block .var (six .b ("out" ))
101
+ out .set_type (core .VarDesc .VarType .LOD_TENSOR )
102
+
103
+ # prepare the operator
104
+ sum_op_desc = block .append_op ()
105
+ sum_op_desc .set_type ("expand" )
106
+ sum_op_desc .set_input ("X" , ["x" ])
107
+ sum_op_desc .set_output ("Out" , ["out" ])
108
+ sum_op_desc ._set_attr ('expand_times' , expand_times )
109
+
110
+ sum_op_desc .check_attrs ()
111
+ sum_op_desc .infer_shape (block )
112
+ self .assertEqual (out .shape (), shape )
113
+
86
114
87
115
if __name__ == '__main__' :
88
116
unittest .main ()
You can’t perform that action at this time.
0 commit comments