@@ -18,6 +18,7 @@ limitations under the License. */
18
18
19
19
#include " paddle/fluid/framework/data_transform.h"
20
20
#include " paddle/fluid/framework/executor.h"
21
+ #include " paddle/fluid/framework/lod_tensor.h"
21
22
#include " paddle/fluid/framework/operator.h"
22
23
#include " paddle/fluid/framework/shape_inference.h"
23
24
#include " paddle/fluid/framework/var_type.h"
@@ -56,13 +57,22 @@ static DDim GetDims(const Scope& scope, const std::string& name,
56
57
return DDim ({-1 });
57
58
}
58
59
59
- if (var->IsType <LoDTensor>()) {
60
- return var->Get <LoDTensor>().dims ();
61
- } else if (var->IsType <SelectedRows>()) {
62
- if (get_actual_dim) {
63
- return var->Get <SelectedRows>().value ().dims ();
60
+ if (var->IsInitialized ()) {
61
+ if (var->IsType <LoDTensor>()) {
62
+ const LoDTensor& tensor = var->Get <LoDTensor>();
63
+ if (tensor.IsInitialized ()) {
64
+ return tensor.dims ();
65
+ } else {
66
+ return DDim ({-1 });
67
+ }
68
+ } else if (var->IsType <SelectedRows>()) {
69
+ if (get_actual_dim) {
70
+ return var->Get <SelectedRows>().value ().dims ();
71
+ } else {
72
+ return var->Get <SelectedRows>().GetCompleteDims ();
73
+ }
64
74
} else {
65
- return var-> Get <SelectedRows>(). GetCompleteDims ( );
75
+ return DDim ({- 1 } );
66
76
}
67
77
} else {
68
78
return DDim ({-1 });
@@ -74,11 +84,21 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
74
84
if (var == nullptr ) {
75
85
return " " ;
76
86
}
77
- if (var->IsType <LoDTensor>()) {
78
- return DataTypeToString (ToDataType (var->Get <LoDTensor>().type ()));
79
- } else if (var->IsType <SelectedRows>()) {
80
- return DataTypeToString (
81
- ToDataType (var->Get <SelectedRows>().value ().type ()));
87
+
88
+ if (var->IsInitialized ()) {
89
+ if (var->IsType <LoDTensor>()) {
90
+ const LoDTensor& tensor = var->Get <LoDTensor>();
91
+ if (tensor.IsInitialized ()) {
92
+ return DataTypeToString (ToDataType (tensor.type ()));
93
+ } else {
94
+ return " " ;
95
+ }
96
+ } else if (var->IsType <SelectedRows>()) {
97
+ return DataTypeToString (
98
+ ToDataType (var->Get <SelectedRows>().value ().type ()));
99
+ } else {
100
+ return " " ;
101
+ }
82
102
} else {
83
103
return " " ;
84
104
}
@@ -90,8 +110,10 @@ static int GetRowSize(const Scope& scope, const std::string& name) {
90
110
return -1 ;
91
111
}
92
112
93
- if (var->IsType <SelectedRows>()) {
94
- return var->Get <SelectedRows>().rows ().size ();
113
+ if (var->IsInitialized ()) {
114
+ if (var->IsType <SelectedRows>()) {
115
+ return var->Get <SelectedRows>().rows ().size ();
116
+ }
95
117
}
96
118
97
119
return -1 ;
@@ -105,8 +127,17 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
105
127
return default_lod;
106
128
}
107
129
108
- if (var->IsType <LoDTensor>()) {
109
- return var->Get <LoDTensor>().lod ();
130
+ if (var->IsInitialized ()) {
131
+ if (var->IsType <LoDTensor>()) {
132
+ const LoDTensor& tensor = var->Get <LoDTensor>();
133
+ if (tensor.IsInitialized ()) {
134
+ return tensor.lod ();
135
+ } else {
136
+ return default_lod;
137
+ }
138
+ } else {
139
+ return default_lod;
140
+ }
110
141
} else {
111
142
return default_lod;
112
143
}
0 commit comments