@@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
534
534
const framework::Tensor& dout, int axis,
535
535
framework::Tensor* dx, framework::Tensor* dy,
536
536
DX_OP dx_op, DY_OP dy_op) {
537
- const framework::DDim x_dim = x.dims ();
538
- const framework::DDim y_dim = y.dims ();
537
+ const framework::DDim& x_dim = x.dims ();
538
+ const framework::DDim& y_dim = y.dims ();
539
539
if (x.dims () == y.dims ()) {
540
540
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
541
541
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
@@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
558
558
framework::Tensor* dx, framework::Tensor* dy,
559
559
DX_OP dx_op, DY_OP dy_op) {
560
560
if (dy == nullptr ) {
561
- const framework::DDim dx_dims = dout.dims ();
561
+ const framework::DDim& dx_dims = dout.dims ();
562
562
auto dy_dims = dx_dims;
563
563
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
564
564
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
565
565
} else {
566
566
if (dout.dims () == dy->dims ()) {
567
- const framework::DDim dx_dims = dout.dims ();
568
- const framework::DDim dy_dims = dy->dims ();
567
+ const framework::DDim& dx_dims = dout.dims ();
568
+ const framework::DDim& dy_dims = dy->dims ();
569
569
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
570
570
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
571
571
} else { // Y is a scalar
572
572
auto dx_dims = dout.dims ();
573
- const framework::DDim dy_dims = dy->dims ();
573
+ const framework::DDim& dy_dims = dy->dims ();
574
574
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
575
575
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
576
576
}
0 commit comments