@@ -33,12 +33,14 @@ class ClipByNormKernel : public framework::OpKernel<T> {
33
33
void Compute (const framework::ExecutionContext& context) const override {
34
34
auto max_norm = context.Attr <T>(" max_norm" );
35
35
auto in_var = context.InputVar (" X" );
36
- auto * output = context.Output <Tensor>(" Out" );
37
- output->mutable_data <T>(context.GetPlace ());
38
36
37
+ Tensor* output = nullptr ;
39
38
const Tensor* input = nullptr ;
40
39
if (in_var->IsType <framework::LoDTensor>()) {
41
40
input = context.Input <Tensor>(" X" );
41
+
42
+ output = context.Output <Tensor>(" Out" );
43
+ output->mutable_data <T>(context.GetPlace ());
42
44
} else if (in_var->IsType <framework::SelectedRows>()) {
43
45
auto * x = context.Input <framework::SelectedRows>(" X" );
44
46
@@ -50,6 +52,11 @@ class ClipByNormKernel : public framework::OpKernel<T> {
50
52
merge_func (context.template device_context <DeviceContext>(), *x,
51
53
merged_input);
52
54
input = &(merged_input->value ());
55
+
56
+ auto * output_selected_rows = context.Output <SelectedRows>(" Out" );
57
+ output_selected_rows->set_rows (merged_input.rows ());
58
+ output = output_selected_rows->mutable_data ();
59
+ output->Resize (framework::make_ddim (merged_input.value ().dims ()));
53
60
} else {
54
61
PADDLE_THROW (" Unexpected branch, input variable type is %s" ,
55
62
in_var->Type ().name ());
0 commit comments