File tree Expand file tree Collapse file tree 2 files changed +33
-0
lines changed
python/paddle/fluid/tests/unittests Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
47
47
out_shape[i] = x_dims[i] * expand_times[i];
48
48
}
49
49
50
+ // set the first dim to -1 in compile time
51
+ if (!ctx->IsRuntime ()) {
52
+ out_shape[0 ] = x_dims[0 ];
53
+ }
54
+
50
55
ctx->SetOutputDim (" Out" , framework::make_ddim (out_shape));
51
56
if (out_shape[0 ] == x_dims[0 ]) {
52
57
ctx->ShareLoD (" X" , " Out" );
Original file line number Diff line number Diff line change @@ -83,6 +83,34 @@ def test_mul_op(self):
83
83
mul_op_desc .infer_shape (block )
84
84
self .assertEqual (out .shape (), [x_shape [0 ], y_shape [1 ]])
85
85
86
+ def test_expand_op (self ):
87
+ prog = core .ProgramDesc ()
88
+ self .assertIsNotNone (prog )
89
+ block = prog .block (0 )
90
+ self .assertIsNotNone (block )
91
+
92
+ shape = [- 1 , 20 ]
93
+ expand_times = [3 , 1 ]
94
+
95
+ # prepare input/output
96
+ x1 = block .var (six .b ("x" ))
97
+ x1 .set_type (core .VarDesc .VarType .LOD_TENSOR )
98
+ x1 .set_shape (shape )
99
+
100
+ out = block .var (six .b ("out" ))
101
+ out .set_type (core .VarDesc .VarType .LOD_TENSOR )
102
+
103
+ # prepare the operator
104
+ sum_op_desc = block .append_op ()
105
+ sum_op_desc .set_type ("expand" )
106
+ sum_op_desc .set_input ("X" , ["x" ])
107
+ sum_op_desc .set_output ("Out" , ["out" ])
108
+ sum_op_desc ._set_attr ('expand_times' , expand_times )
109
+
110
+ sum_op_desc .check_attrs ()
111
+ sum_op_desc .infer_shape (block )
112
+ self .assertEqual (out .shape (), shape )
113
+
86
114
87
115
if __name__ == '__main__' :
88
116
unittest .main ()
You can’t perform that action at this time.
0 commit comments