File tree Expand file tree Collapse file tree 3 files changed +10
-3
lines changed Expand file tree Collapse file tree 3 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -22,6 +22,12 @@ std::vector<framework::DDim> InferShapeContext::GetInputsDim(
22
22
return GetDims (names);
23
23
}
24
24
25
+ DDim InferShapeContext::GetInputsElementDim (const std::string &name,
26
+ int idx) const {
27
+ const std::vector<std::string> &names = Inputs (name);
28
+ return this ->GetDim (names[idx]);
29
+ }
30
+
25
31
void InferShapeContext::SetOutputsDim (
26
32
const std::string &name, const std::vector<framework::DDim> &dims) {
27
33
auto &names = Outputs (name);
Original file line number Diff line number Diff line change @@ -37,6 +37,7 @@ class InferShapeContext {
37
37
virtual framework::DDim GetInputDim (const std::string &name) const = 0;
38
38
39
39
std::vector<framework::DDim> GetInputsDim (const std::string &name) const ;
40
+ DDim GetInputsElementDim (const std::string &name, int idx) const ;
40
41
41
42
virtual void SetOutputDim (const std::string &name, const DDim &dim) = 0;
42
43
void SetOutputsDim (const std::string &name,
Original file line number Diff line number Diff line change @@ -287,21 +287,21 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
287
287
288
288
auto p_names = ctx->Inputs (kParameters );
289
289
auto pg_names = ctx->Outputs (kParamGrads );
290
- auto dims = ctx->GetInputsDim (kParameters );
291
290
auto var_types = ctx->GetInputsVarType (kParameters );
292
291
std::vector<std::string> names_to_set;
293
292
std::vector<framework::DDim> dims_to_set;
294
293
for (size_t i = 0 ; i < p_names.size (); ++i) {
295
294
if (pg_names[i] == framework::kEmptyVarName ) {
296
295
continue ;
297
296
}
297
+ auto dims = ctx->GetInputsElementDim (kParameters , i);
298
298
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
299
299
names_to_set.push_back (pg_names[i]);
300
- dims_to_set.push_back (dims[i] );
300
+ dims_to_set.push_back (dims);
301
301
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
302
302
// not sure how to set the dim of LOD_TENSOR_ARRAY
303
303
names_to_set.push_back (pg_names[i]);
304
- dims_to_set.push_back (dims[i] );
304
+ dims_to_set.push_back (dims);
305
305
}
306
306
}
307
307
ctx->SetDims (names_to_set, dims_to_set);
You can’t perform that action at this time.
0 commit comments