Skip to content

Commit f20fc95

Browse files
committed
Resize output ddims and rows
1 parent 6730882 commit f20fc95

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

paddle/fluid/operators/clip_by_norm_op.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ class ClipByNormKernel : public framework::OpKernel<T> {
3333
void Compute(const framework::ExecutionContext& context) const override {
3434
auto max_norm = context.Attr<T>("max_norm");
3535
auto in_var = context.InputVar("X");
36-
auto* output = context.Output<Tensor>("Out");
37-
output->mutable_data<T>(context.GetPlace());
3836

37+
Tensor* output = nullptr;
3938
const Tensor* input = nullptr;
4039
if (in_var->IsType<framework::LoDTensor>()) {
4140
input = context.Input<Tensor>("X");
41+
42+
output = context.Output<Tensor>("Out");
43+
output->mutable_data<T>(context.GetPlace());
4244
} else if (in_var->IsType<framework::SelectedRows>()) {
4345
auto* x = context.Input<framework::SelectedRows>("X");
4446

@@ -50,6 +52,11 @@ class ClipByNormKernel : public framework::OpKernel<T> {
5052
merge_func(context.template device_context<DeviceContext>(), *x,
5153
merged_input);
5254
input = &(merged_input->value());
55+
56+
auto* output_selected_rows = context.Output<SelectedRows>("Out");
57+
output_selected_rows->set_rows(merged_input.rows());
58+
output = output_selected_rows->mutable_data();
59+
output->Resize(framework::make_ddim(merged_input.value().dims()));
5360
} else {
5461
PADDLE_THROW("Unexpected branch, input variable type is %s",
5562
in_var->Type().name());

0 commit comments

Comments
 (0)