Commit 8a60047

Fix undefined bias for conv batching rule
Fixes part of #338
1 parent 00ce7dc

2 files changed: 3 additions & 3 deletions

functorch/csrc/BatchRulesModules.cpp

Lines changed: 2 additions & 1 deletion
@@ -23,8 +23,9 @@ convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tens
   // If we have a batched bias or weight, we need to perform the computation separately.
   optional<Tensor> unbatched_bias;
   bool separate_bias;
-  if ((rhs_bdim && bias) || bias_bdim) {
+  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
     TORCH_INTERNAL_ASSERT(bias.has_value());
+    TORCH_INTERNAL_ASSERT(bias->defined());
     unbatched_bias = nullopt;
     separate_bias = true;
   } else {
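
For context: a c10::optional<Tensor> can be engaged while holding an undefined Tensor, which is presumably how a Python-side bias=None reached this batch rule, so has_value() alone does not guarantee the bias is usable. A minimal standalone sketch of that distinction (assumes an ATen build; this snippet is illustrative and not part of the commit):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // An engaged optional wrapping a default-constructed (undefined) Tensor.
  c10::optional<at::Tensor> bias = at::Tensor();
  std::cout << bias.has_value() << "\n";  // prints 1: the optional is engaged
  std::cout << bias->defined() << "\n";   // prints 0: the Tensor is undefined
  // Hence the fixed guard checks both has_value() and defined().
  return 0;
}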

test/test_ops.py

Lines changed: 1 addition & 2 deletions
@@ -469,7 +469,6 @@ def vjp_of_vjp(*args_and_cotangents):
 
     # PyTorch changed its convolution recently.
     # Maybe it is responsible for all of the following changes.
-    xfail('nn.functional.conv1d'),
     xfail('nn.functional.conv_transpose1d'),
     xfail('nn.functional.conv_transpose2d'),
     xfail('nn.functional.conv_transpose3d'),
@@ -479,7 +478,7 @@ def vjp_of_vjp(*args_and_cotangents):
     @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
     def test_vmapvjp(self, device, dtype, op):
         # These are too annoying to put into the list above
-        if op.name in {'nn.functional.linear', 'nn.functional.conv2d'}:
+        if op.name in {'nn.functional.linear'}:
             self.skipTest("Skipped! ExpectedF failures")
         if not op.supports_autograd:
             self.skipTest("Skipped! Autograd not supported.")
