Commit 39b2769

pianpwk authored and facebook-github-bot committed
[export] maybe fix conv.backward for joint graph export (#5450)
Summary:
Pull Request resolved: pytorch/executorch#5450
Differential Revision: D62910149
1 parent f6f1504 commit 39b2769
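
For context, a minimal sketch (not part of the commit) of the flow being fixed: export a model whose forward returns a scalar loss, then lower the ExportedProgram to a joint forward+backward graph. `TinyConv` is a hypothetical stand-in for the CIFAR10 net in the new test, and the sketch assumes `_export_forward_backward` is imported from `torch.export.experimental`, where the test suite uses it.

import torch
import torch.nn as nn
from torch.export import export
from torch.export.experimental import _export_forward_backward

class TinyConv(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, labels):
        logits = self.conv(x).mean(dim=(2, 3))  # [N, 8] class scores
        return self.loss(logits, labels)

ep = export(TinyConv(), (torch.randn(2, 3, 8, 8), torch.zeros(2, dtype=torch.int64)))
joint_ep = _export_forward_backward(ep)  # traces conv's backward through its meta kernel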

File tree

3 files changed: +39 −11 lines

test/export/test_experimental.py
torch/_functorch/aot_autograd.py
torch/_meta_registrations.py

test/export/test_experimental.py
Lines changed: 34 additions & 0 deletions

@@ -327,6 +327,40 @@ def forward(self, x):
         )
         joint_ep = _export_forward_backward(ep)

+    def test_joint_cifar10_backwards(self) -> None:
+        import torch.nn as nn
+        import torch.nn.functional as F
+
+        # From PyTorch's CIFAR10 example:
+        # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(3, 6, 5)
+                self.pool = nn.MaxPool2d(2, 2)
+                self.conv2 = nn.Conv2d(6, 16, 5)
+                self.fc1 = nn.Linear(16 * 5 * 5, 120)
+                self.fc2 = nn.Linear(120, 84)
+                self.fc3 = nn.Linear(84, 10)
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, x, labels):
+                x = self.pool(F.relu(self.conv1(x)))
+                x = self.pool(F.relu(self.conv2(x)))
+                x = torch.flatten(x, 1)  # flatten all dimensions except batch
+                x = F.relu(self.fc1(x))
+                x = F.relu(self.fc2(x))
+                x = self.fc3(x)
+                return self.loss(x, labels)
+
+        net = Net()
+        x = torch.randn(4, 3, 32, 32)
+        labels = torch.ones(4, dtype=torch.int64)
+        inputs = (x, labels)
+
+        ep = export(net, inputs)
+        _export_forward_backward(ep)
+

 if __name__ == "__main__":
     run_tests()
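
A possible follow-on check, continuing from the test's `ep` above (illustrative, not in the commit): inspect the joint program for the convolution backward op whose meta kernel this commit touches. `graph_module` is the standard ExportedProgram accessor; whether `aten.convolution_backward` appears verbatim depends on which decompositions are in effect.

joint_ep = _export_forward_backward(ep)
joint_targets = {
    n.target
    for n in joint_ep.graph_module.graph.nodes
    if n.op == "call_function"
}
# The op this commit's meta-kernel change unblocks:
print(torch.ops.aten.convolution_backward.default in joint_targets)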

torch/_functorch/aot_autograd.py
Lines changed: 2 additions & 1 deletion

@@ -1295,7 +1295,8 @@ def flattened_joint(*args):
             assert grad is None
             return *fw_outs, *output_gradients

-        fx_g = make_fx(flattened_joint)(*full_args)
+        flattened_joint._orig_mod = fx_g
+        fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args)

     user_args_flat = pytree.arg_tree_leaves(*args, **kwargs)
     return fx_g, create_graph_signature(
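
For readers unfamiliar with the flag: `record_module_stack=True` asks `make_fx` to attach `nn_module_stack` metadata to traced nodes, and pointing `flattened_joint._orig_mod` at the previously traced graph module gives the tracer something to attribute those nodes to. A minimal sketch of the metadata (illustrative; `M` is a hypothetical module, and the behavior is assumed from the flag's name and this diff):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x).relu()

gm = make_fx(M(), record_module_stack=True)(torch.randn(2, 4))
for node in gm.graph.nodes:
    # each ATen node should carry the submodule it was traced from
    print(node.op, node.target, node.meta.get("nn_module_stack"))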

torch/_meta_registrations.py
Lines changed: 3 additions & 10 deletions

@@ -3110,16 +3110,9 @@ def meta_convolution_backward(
 ):
     # High level logic taken from slow_conv3d_backward_cpu which should
     # be representative of all convolution_backward impls
-    backend_grad_input = None
-    backend_grad_weight = None
-    backend_grad_bias = None
-
-    if output_mask[0]:
-        backend_grad_input = grad_output_.new_empty(input_.size())
-    if output_mask[1]:
-        backend_grad_weight = grad_output_.new_empty(weight_.size())
-    if output_mask[2]:
-        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
+    backend_grad_input = grad_output_.new_empty(input_.size())
+    backend_grad_weight = grad_output_.new_empty(weight_.size())
+    backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)

     return (backend_grad_input, backend_grad_weight, backend_grad_bias)
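
The practical effect, sketched on the meta device (illustrative; the shapes follow the test's conv1 layer, and the positional arguments mirror the op schema: grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask): after this change all three gradients are materialized even when output_mask disables one, where the old registration returned None for masked entries.

import torch

grad_out = torch.empty(4, 6, 28, 28, device="meta")  # grad w.r.t. conv1 output
x = torch.empty(4, 3, 32, 32, device="meta")
w = torch.empty(6, 3, 5, 5, device="meta")

gi, gw, gb = torch.ops.aten.convolution_backward(
    grad_out, x, w, [6],
    [1, 1], [0, 0], [1, 1],   # stride, padding, dilation
    False, [0, 0], 1,         # transposed, output_padding, groups
    [True, True, False],      # output_mask: bias grad switched off
)
print(gb)  # with this commit: an empty meta tensor of shape [6], not None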
