|
10 | 10 | limitations under the License. */
|
11 | 11 |
|
12 | 12 | #include "paddle/fluid/operators/spectral_norm_op.h"
|
| 13 | + |
| 14 | +#include <memory> |
| 15 | + |
13 | 16 | #include "paddle/fluid/framework/op_registry.h"
|
14 | 17 |
|
15 | 18 | namespace paddle {
|
@@ -156,6 +159,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
156 | 159 | }
|
157 | 160 | };
|
158 | 161 |
|
| 162 | +class SpectralNormGradOpDescMaker : public framework::SingleGradOpDescMaker { |
| 163 | + public: |
| 164 | + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; |
| 165 | + |
| 166 | + protected: |
| 167 | + std::unique_ptr<framework::OpDesc> Apply() const override { |
| 168 | + std::unique_ptr<framework::OpDesc> op(new framework::OpDesc()); |
| 169 | + op->SetType("spectral_norm_grad"); |
| 170 | + |
| 171 | + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); |
| 172 | + op->SetInput("Weight", Input("Weight")); |
| 173 | + op->SetInput("U", Input("U")); |
| 174 | + op->SetInput("V", Input("V")); |
| 175 | + |
| 176 | + op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight")); |
| 177 | + |
| 178 | + op->SetAttrMap(Attrs()); |
| 179 | + |
| 180 | + return op; |
| 181 | + } |
| 182 | +}; |
| 183 | + |
159 | 184 | class SpectralNormOpGrad : public framework::OperatorWithKernel {
|
160 | 185 | public:
|
161 | 186 | using framework::OperatorWithKernel::OperatorWithKernel;
|
@@ -185,7 +210,7 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
|
185 | 210 |
|
186 | 211 | namespace ops = paddle::operators;
|
187 | 212 | REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
|
188 |
| - paddle::framework::DefaultGradOpDescMaker<true>); |
| 213 | + ops::SpectralNormGradOpDescMaker); |
189 | 214 | REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
|
190 | 215 | REGISTER_OP_CPU_KERNEL(
|
191 | 216 | spectral_norm,
|
|
0 commit comments