Commit 43c5626

Author: Feiyu Chan

cherry pick: (#24269) fix kron_op: when only one input needs gradient (#24270)

* fix kron_op: when only one input needs gradient, test=develop
* fix a typo in paddle.complex.matmul, test=release/1.8
1 parent f493268 commit 43c5626
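The change in a nutshell: kron's backward op used to require both X@GRAD and Y@GRAD unconditionally, so asking for the gradient of only one input (for example by excluding the other via no_grad_set) failed during shape inference. A minimal repro sketch in the OpTest style used by the tests below (the test class and method names are hypothetical; it assumes it runs from python/paddle/fluid/tests/unittests in a Paddle source tree, where op_test is importable):

import unittest
import numpy as np
from op_test import OpTest

class TestKronGradOneSide(OpTest):
    def setUp(self):
        self.op_type = "kron"
        x = np.random.uniform(size=(10, 10)).astype("float64")
        y = np.random.uniform(size=(10, 10)).astype("float64")
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.kron(x, y)}

    def test_grad_for_y_only(self):
        # Before this commit, dropping X from the gradient computation made
        # KronGradOp's shape inference fail on the missing X@GRAD output.
        self.check_grad(['Y'], 'Out', no_grad_set=set('X'))

if __name__ == "__main__":
    unittest.main()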

File tree: 6 files changed, +77 -38 lines

paddle/fluid/operators/kron_op.cc

Lines changed: 6 additions & 8 deletions

This guards shape inference: X@GRAD and Y@GRAD get their dims set only if the corresponding output variable actually exists, instead of being unconditionally required (the ShareLoD calls are dropped along with the checks).
@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                   framework::GradVarName("X"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
-                   framework::GradVarName("Y"), "kron_grad");
 
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
-    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ x_grad_name);
-    ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
-    ctx->ShareLoD("Y", /*->*/ y_grad_name);
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
+    }
   }
 
  protected:

paddle/fluid/operators/kron_op.h

Lines changed: 61 additions & 27 deletions

A null destination pointer now means "this gradient is not requested": the element functor writes each scratch gradient only when its pointer is non-null, and the functor and kernel allocate, unsqueeze, and reduce each gradient only when it is needed.
@@ -147,11 +147,14 @@ struct KronGradElemFunctor {
       index_b += stride_b_[i] * pos_bi;
     }
 
-    size_t index_out_a = index_a * numel_b_ + index_b;
-    size_t index_out_b = index_b * numel_a_ + index_a;
-
-    dout_a_[index_out_a] = dout_[idx] * B_[index_b];
-    dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    }
   }
 
  private:
@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
     // dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y)
     // dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x)
     framework::Tensor dout_x;
-    dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+    T* p_dout_x = nullptr;
+    if (dx) {
+      dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+      p_dout_x = dout_x.data<T>();
+    }
     framework::Tensor dout_y;
-    dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+    T* p_dout_y = nullptr;
+    if (dy) {
+      dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+      p_dout_y = dout_y.data<T>();
+    }
 
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
     KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
-                                dout_x.data<T>(), dout_y.data<T>(),
-                                p_stride_dout, p_stride_x, p_stride_y,
-                                p_shape_y, numel_x, numel_y, ndims);
+                                p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
+                                p_stride_y, p_shape_y, numel_x, numel_y, ndims);
     for_range(func);
 
     // reduce_sum along aixs 1
 #if __NVCC__
     auto stream = dev_ctx.stream();  // it is a cuda device_context
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
+    if (dx) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
+    if (dy) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
 #else
-    auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
-    auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
-    auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
-    auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
     auto* place = dev_ctx.eigen_device();
     Eigen::array<int, 1> reduce_dim = {1};
-    eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
-    eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    if (dx) {
+      auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
+      auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
+      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
+    }
+    if (dy) {
+      auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
+      auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
+      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    }
 #endif
   }
 };
@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
 
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    dx->mutable_data<T>(ctx.GetPlace());
-    dy->mutable_data<T>(ctx.GetPlace());
+    if (dx) {
+      dx->mutable_data<T>(ctx.GetPlace());
+    }
+    if (dy) {
+      dy->mutable_data<T>(ctx.GetPlace());
+    }
 
     int ndims = dout->dims().size();
     framework::Tensor xx = UnsqueezeTo(*x, ndims);
-    framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
     framework::Tensor yy = UnsqueezeTo(*y, ndims);
-    framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
+
+    framework::Tensor* pdxx = nullptr;
+    framework::Tensor* pdyy = nullptr;
+    framework::Tensor dxx;
+    framework::Tensor dyy;
+    if (dx) {
+      dxx = UnsqueezeTo(*dx, ndims);
+      pdxx = &dxx;
+    }
+
+    if (dy) {
+      dyy = UnsqueezeTo(*dy, ndims);
+      pdyy = &dyy;
+    }
 
     KronGradOpFunctor<DeviceContext, T> func;
-    func(dev_ctx, *dout, xx, yy, &dxx, &dyy);
+    func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
   }
 };
 
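As the comments in KronGradOpFunctor say, each gradient is an elementwise product with the other operand followed by a reduce_sum along axis 1; the two gradients never depend on each other, which is what makes the null-pointer short-circuiting above safe. A reference NumPy sketch for 2-D real inputs (illustrative only; the helper name kron_grads is mine, not Paddle's):

import numpy as np

def kron_grads(dout, x, y, need_dx=True, need_dy=True):
    """Reference gradients of z = np.kron(x, y) for 2-D real arrays.

    Since z[i*p + k, j*q + l] = x[i, j] * y[k, l]:
      dx[i, j] = sum_{k,l} dout[i*p + k, j*q + l] * y[k, l]
      dy[k, l] = sum_{i,j} dout[i*p + k, j*q + l] * x[i, j]
    """
    m, n = x.shape
    p, q = y.shape
    d = dout.reshape(m, p, n, q)  # d[i, k, j, l] == dout[i*p + k, j*q + l]
    dx = np.einsum('ikjl,kl->ij', d, y) if need_dx else None
    dy = np.einsum('ikjl,ij->kl', d, x) if need_dy else None
    return dx, dy

# Sanity check: kron is bilinear, so <dout, kron(x, y)> == <dx, x> == <dy, y>.
x, y = np.random.randn(2, 3), np.random.randn(4, 5)
dout = np.random.randn(8, 15)
dx, dy = kron_grads(dout, x, y)
assert np.allclose((dout * np.kron(x, y)).sum(), (dx * x).sum())
assert np.allclose((dout * np.kron(x, y)).sum(), (dy * y).sum())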

python/paddle/complex/tensor/linalg.py

Lines changed: 2 additions & 2 deletions
@@ -26,10 +26,10 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
 
     Args:
         x (ComplexVariable|Variable): The first input, can be a ComplexVariable
-            with data type complex32 or complex64, or a Variable with data type
+            with data type complex64 or complex128, or a Variable with data type
             float32 or float64.
         y (ComplexVariable|Variable): The second input, can be a ComplexVariable
-            with data type complex32 or complex64, or a Variable with data type
+            with data type complex64 or complex128, or a Variable with data type
             float32 or float64.
         transpose_x (bool): Whether to transpose :math:`x` before multiplication.
         transpose_y (bool): Whether to transpose :math:`y` before multiplication.
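For reference, a minimal usage sketch with the corrected dtypes. This example is assumption-laden: it presumes the paddle 1.8 complex dygraph workflow (dg.to_variable on a complex128 array yields a ComplexVariable), modeled on the kron docstring example in math.py below:

import numpy as np
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg

x = np.array([[1.0 + 1.0j, 2.0 + 1.0j]], dtype=np.complex128)    # shape [1, 2]
y = np.array([[3.0 - 1.0j], [4.0 - 1.0j]], dtype=np.complex128)  # shape [2, 1]

with dg.guard(fluid.CPUPlace()):
    a = dg.to_variable(x)
    b = dg.to_variable(y)
    out = paddle.complex.matmul(a, b)  # ComplexVariable of shape [1, 1]
    print(out.numpy())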

python/paddle/complex/tensor/math.py

Lines changed: 1 addition & 0 deletions
@@ -367,6 +367,7 @@ def kron(x, y, name=None):
 
           import numpy as np
           import paddle
+          from paddle import fluid
           import paddle.fluid.dygraph as dg
 
           a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
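The hunk shows only the example's imports; the new `from paddle import fluid` line suggests the rest of the docstring example references fluid (most plausibly fluid.CPUPlace()). A self-contained variant under that assumption, using the paddle 1.8 complex dygraph API:

import numpy as np
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg

a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]])

place = fluid.CPUPlace()
with dg.guard(place):
    x = dg.to_variable(a)
    y = dg.to_variable(b)
    out = paddle.complex.kron(x, y)  # shape [4, 4]
    print(out.numpy())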

python/paddle/fluid/tests/unittests/test_kron_op.py

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,12 @@ def test_check_output(self):
     def test_check_grad(self):
         self.check_grad(['X', 'Y'], 'Out')
 
+    def test_check_grad_ignore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set('X'))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+
 
 class TestKronOp2(TestKronOp):
     def setUp(self):
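Note that no_grad_set=set('X') is simply {'X'}, since the op's input names are single characters. To run only the new regression tests (a standard unittest invocation, assuming the working directory is python/paddle/fluid/tests/unittests):

python -m unittest test_kron_op.TestKronOp.test_check_grad_ignore_x
python -m unittest test_kron_op.TestKronOp.test_check_grad_ignore_y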

python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # check no_grad_set is None
-NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv']
+NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv', 'kron']
 
 # TODO(Shixiaowei02): Check if the items do not need fix.
 # no_grad_set has value in NEED_TO_FIX_OP_LIST
