
Commit 442483c: fix CI issues with AOTAutograd
1 parent: 0e08bf5

File tree (5 files changed: +19 -19 lines)

  functorch/_src/compile_utils.py
  functorch/_src/partitioners.py
  test/test_memory_efficient_fusion.py
  test/test_minifier.py
  test/test_pythonkey.py

functorch/_src/compile_utils.py

Lines changed: 9 additions & 1 deletion
@@ -4,6 +4,14 @@
 from torch.utils._pytree import tree_flatten
 
 aten = torch.ops.aten
+
+
+def get_aten_target(node):
+    if hasattr(node.target, 'overloadpacket'):
+        return node.target.overloadpacket
+    return node.target
+
+
 rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma,
             aten.bernoulli, aten.multinomial, aten.native_dropout,
             aten.normal, aten.poisson, aten.binomial, aten.rrelu,
@@ -19,7 +27,7 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph):
     for n in fx_g.nodes:
         # The placeholder, output, and get_attr nodes are copied to the new grpah without change
         # do not CSE away random operations
-        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or n.target in rand_ops:
+        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in rand_ops:
             new_node = new_graph.node_copy(n, lambda x: env[x])
             env[n] = new_node
         else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
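
For context, the new helper maps a traced node's target back to its OpOverloadPacket: targets recorded on traced nodes are typically OpOverload objects (e.g. aten.mul.Tensor), which expose the parent packet via their overloadpacket attribute, while packet objects (e.g. aten.mul) lack that attribute and are returned unchanged. A minimal sketch, not part of the commit, of the attribute the helper relies on:

import torch

aten = torch.ops.aten

overload = aten.mul.Tensor          # an OpOverload, as recorded on traced graph nodes
packet = overload.overloadpacket    # the parent OpOverloadPacket, i.e. aten.mul

print(packet is aten.mul)                    # expected: True
print(hasattr(aten.mul, 'overloadpacket'))   # expected: False, so packets pass through unchanged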

functorch/_src/partitioners.py

Lines changed: 5 additions & 6 deletions
@@ -7,7 +7,7 @@
 import os
 from torch.fx.passes import graph_drawer
 from typing import Tuple
-from .compile_utils import fx_graph_cse, strip_overloads
+from .compile_utils import fx_graph_cse, get_aten_target
 
 
 class InvalidNodeBase(object):
@@ -227,7 +227,6 @@ def min_cut_rematerialization_partition(
     except ImportError:
         raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
 
-    strip_overloads(joint_module)
     joint_module.graph.eliminate_dead_code()
     joint_module.recompile()
     fx_g = joint_module.graph
@@ -299,22 +298,22 @@ def classify_nodes(joint_module):
 
     def ban_recomputation(node):
        if AGGRESSIVE_RECOMPUTATION:
-            return (node.op == 'call_function' and node.target in unrecomputable_ops)
+            return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
-            if node.target not in recomputable_ops:
+            if get_aten_target(node) not in recomputable_ops:
                return True
            # If the output of the reduction is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
-            if node.target in reduction_ops:
+            if get_aten_target(node) in reduction_ops:
                input_tensors_size = sum(_size_of(i.meta['tensor_meta']) for i in node.args if isinstance(i, fx.Node))
                output_size = _size_of(node.meta['tensor_meta'])
                return (output_size * 4 < input_tensors_size)
            return False
 
     def is_fusible(a, b):
-        return a.target in fusible_ops and b.target in fusible_ops
+        return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops
 
     def is_materialized(node):
        if node.op == 'placeholder':
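
The partitioner previously rewrote the joint graph with strip_overloads so that node targets matched the OpOverloadPackets stored in sets such as recomputable_ops and fusible_ops; with get_aten_target the normalization now happens at comparison time instead. A hedged sketch of why that matters (fusible_ops here is an illustrative stand-in, not the partitioner's actual set):

import torch

aten = torch.ops.aten

fusible_ops = {aten.add, aten.mul}   # illustrative set of OpOverloadPackets
target = aten.add.Tensor             # the kind of target a traced node carries

print(target in fusible_ops)                    # expected: False, an overload is not its packet
print(target.overloadpacket in fusible_ops)     # expected: True once normalized to the packet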

test/test_memory_efficient_fusion.py

Lines changed: 3 additions & 9 deletions
@@ -239,8 +239,6 @@ def f(x):
         t = torch.randn(2, 2)
         check(f, t, 0)
 
-    # https://github.com/pytorch/functorch/issues/913
-    @unittest.expectedFailure
     def test_rand_like(self):
         def f(x):
             a = torch.rand_like(x)
@@ -249,17 +247,13 @@ def f(x):
         t = torch.randn(2, 2)
         check(f, t, 0, check_val=False)
 
-    # https://github.com/pytorch/functorch/issues/913
-    @unittest.expectedFailure
     def test_rand_n(self):
         def f(x):
-            g_cpu = torch.Generator()
-            g_cpu.manual_seed(2147483647)
-            a = torch.randn(4, generator=g_cpu)
-            b = torch.randn(4, generator=g_cpu)
+            a = torch.randn(4)
+            b = torch.randn(4)
             return a + b
         t = torch.randn(2, 2)
-        check(f, t, 0)
+        check(f, t, 0, check_val=False)
 
 
 class ReduceTestCase(TestCase):

test/test_minifier.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 
 class TestMinifier(TestCase):
     # https://github.com/pytorch/functorch/issues/913
-    @unittest.expectedFailure
     def test_has_mul_minifier(self):
         def failing_f(x, y):
             y = y / 3
@@ -18,7 +17,7 @@ def failing_f(x, y):
         failing_f = make_fx(failing_f)(*inps)
 
         def pass_checker(fx_g, inps):
-            return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))
 
         min_f, inps = minifier(failing_f, inps, pass_checker)
         assert len(min_f.graph.nodes) == 4
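
The checker changes because graphs traced through make_fx now record OpOverload targets such as aten.mul.Tensor rather than the bare OpOverloadPacket. A hedged sketch (assuming functorch's make_fx is importable as shown) of what the traced node targets look like:

import torch
from functorch import make_fx


def f(x, y):
    return x * y


fx_g = make_fx(f)(torch.randn(3), torch.randn(3))
targets = {n.target for n in fx_g.graph.nodes if n.op == 'call_function'}

print(torch.ops.aten.mul.Tensor in targets)   # expected: True, the overload is recorded
print(torch.ops.aten.mul in targets)          # expected: False, the packet is not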

test/test_pythonkey.py

Lines changed: 1 addition & 1 deletion
@@ -501,7 +501,7 @@ def f(x):
         self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
 
         ins, outs = get_ins_outs(fw_graph)
-        self.assertEqual(outs[1].target, torch.ops.aten.mm)
+        self.assertEqual(outs[1].target, torch.ops.aten.mm.default)
 
 
 class TestContiguous(TestCase):
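
As in the minifier test, the assertion now names the specific overload carried by the node. The overload's name depends on the op's schema: aten::mm declares a single unnamed overload, addressed as aten.mm.default, while ops such as aten::mul declare named overloads like aten.mul.Tensor. A small hedged check, not part of the commit:

import torch

aten = torch.ops.aten

# aten::mm has one unnamed overload, addressed as `.default`.
print(aten.mm.default.overloadpacket is aten.mm)     # expected: True

# aten::mul has named overloads (e.g. Tensor, Scalar), addressed by name.
print(aten.mul.Tensor.overloadpacket is aten.mul)    # expected: True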
