Skip to content

Commit 5f1c92a

Browse files
authored
Merge pull request #16450 from zhhsplendid/del-redundant-op-var-reg
Add SpectralNormGradOpDescMaker
2 parents ecc3088 + 3909108 commit 5f1c92a

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

paddle/fluid/operators/spectral_norm_op.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
limitations under the License. */
1111

1212
#include "paddle/fluid/operators/spectral_norm_op.h"
13+
14+
#include <memory>
15+
1316
#include "paddle/fluid/framework/op_registry.h"
1417

1518
namespace paddle {
@@ -156,6 +159,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
156159
}
157160
};
158161

162+
class SpectralNormGradOpDescMaker : public framework::SingleGradOpDescMaker {
163+
public:
164+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
165+
166+
protected:
167+
std::unique_ptr<framework::OpDesc> Apply() const override {
168+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
169+
op->SetType("spectral_norm_grad");
170+
171+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
172+
op->SetInput("Weight", Input("Weight"));
173+
op->SetInput("U", Input("U"));
174+
op->SetInput("V", Input("V"));
175+
176+
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
177+
178+
op->SetAttrMap(Attrs());
179+
180+
return op;
181+
}
182+
};
183+
159184
class SpectralNormOpGrad : public framework::OperatorWithKernel {
160185
public:
161186
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -185,7 +210,7 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
185210

186211
namespace ops = paddle::operators;
187212
REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
188-
paddle::framework::DefaultGradOpDescMaker<true>);
213+
ops::SpectralNormGradOpDescMaker);
189214
REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
190215
REGISTER_OP_CPU_KERNEL(
191216
spectral_norm,

0 commit comments

Comments
 (0)