Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit c2e72cc

Browse files
committed
Did some dead code elimination before partitioning
1 parent d11b5c2 commit c2e72cc

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

functorch/_src/compilers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def nnc_jit(f, static_argnums=None):
266266
aten.hardswish,
267267
aten.hardsigmoid,
268268
}
269+
269270
default_decompositions = get_decompositions(default_decompositions)
270271

271272

functorch/_src/partitioners.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,12 @@ def min_cut_rematerialization_partition(
227227
except ImportError:
228228
raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
229229

230-
# add the CSE pass
231230
strip_overloads(joint_module)
231+
joint_module.graph.eliminate_dead_code()
232+
joint_module.recompile()
232233
fx_g = joint_module.graph
234+
235+
# add the CSE pass
233236
cse_graph = fx_graph_cse(fx_g)
234237
joint_module.graph = cse_graph
235238
full_bw_graph = joint_module.graph
@@ -373,8 +376,6 @@ def get_node_weight(node):
373376
# To make this stuff deterministic
374377
node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
375378
saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
376-
saved_values = [name_to_node[node] for node in cut_nodes]
377-
378379
return _extract_fwd_bwd_modules(joint_module, saved_values)
379380

380381

test/test_pythonkey.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ class TestEagerFusionOpInfo(TestCase):
314314
xfail('linalg.cholesky'),
315315
skip('msort'),
316316
xfail('nn.functional.dropout'),
317-
xfail('polar'),
318317
xfail('to_sparse'),
319318
xfail('addcdiv'),
320319
xfail('cholesky'),
@@ -326,6 +325,8 @@ class TestEagerFusionOpInfo(TestCase):
326325
xfail('matrix_exp'),
327326
xfail('trapezoid'),
328327
xfail('trapz'),
328+
xfail('corrcoef'),
329+
xfail('cov'),
329330
skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes?
330331
skip('nn.functional.margin_ranking_loss'), # seems flaky
331332
})

0 commit comments

Comments
 (0)