@@ -34,7 +34,7 @@ bool check_convolution_backward_args(
3434 bool transposed,
3535 IntArrayRef output_padding,
3636 int64_t groups,
37- ET_UNUSED executorch::aten::ArrayRef<bool > output_mask,
37+ executorch::aten::ArrayRef<bool > output_mask,
3838 Tensor& grad_input,
3939 Tensor& grad_weight,
4040 Tensor& grad_bias) {
@@ -45,9 +45,18 @@ bool check_convolution_backward_args(
4545
4646 ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (weight, input));
4747 ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_output, input));
48- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_input, input));
49- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_weight, input));
50- ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_bias, input));
48+
49+ if (output_mask[0 ]) {
50+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_input, input));
51+ }
52+
53+ if (output_mask[1 ]) {
54+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_weight, input));
55+ }
56+
57+ if (output_mask[2 ]) {
58+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (grad_bias, input));
59+ }
5160
5261 ET_LOG_MSG_AND_RETURN_IF_FALSE (
5362 check_convolution_args (
@@ -267,19 +276,23 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
267276 InvalidArgument,
268277 ret_val);
269278
270- ET_KERNEL_CHECK (
271- ctx,
272- resize_tensor (grad_input, input.sizes ()) == Error::Ok,
273- InvalidArgument,
274- ret_val);
279+ if (output_mask[0 ]) {
280+ ET_KERNEL_CHECK (
281+ ctx,
282+ resize_tensor (grad_input, input.sizes ()) == Error::Ok,
283+ InvalidArgument,
284+ ret_val);
285+ }
275286
276- ET_KERNEL_CHECK (
277- ctx,
278- resize_tensor (grad_weight, weight.sizes ()) == Error::Ok,
279- InvalidArgument,
280- ret_val);
287+ if (output_mask[1 ]) {
288+ ET_KERNEL_CHECK (
289+ ctx,
290+ resize_tensor (grad_weight, weight.sizes ()) == Error::Ok,
291+ InvalidArgument,
292+ ret_val);
293+ }
281294
282- if (bias_sizes_opt.has_value ()) {
295+ if (bias_sizes_opt.has_value () && output_mask[ 2 ] ) {
283296 ET_KERNEL_CHECK (
284297 ctx,
285298 resize_tensor (grad_bias, bias_sizes_opt.value ()) == Error::Ok,
0 commit comments