This repository was archived by the owner on Aug 21, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 3 files changed +7
-4
lines changed Expand file tree Collapse file tree 3 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -266,6 +266,7 @@ def nnc_jit(f, static_argnums=None):
266
266
aten .hardswish ,
267
267
aten .hardsigmoid ,
268
268
}
269
+
269
270
default_decompositions = get_decompositions (default_decompositions )
270
271
271
272
Original file line number Diff line number Diff line change @@ -227,9 +227,12 @@ def min_cut_rematerialization_partition(
227
227
except ImportError :
228
228
raise RuntimeError ("Need networkx installed to perform smart recomputation heuristics" )
229
229
230
- # add the CSE pass
231
230
strip_overloads (joint_module )
231
+ joint_module .graph .eliminate_dead_code ()
232
+ joint_module .recompile ()
232
233
fx_g = joint_module .graph
234
+
235
+ # add the CSE pass
233
236
cse_graph = fx_graph_cse (fx_g )
234
237
joint_module .graph = cse_graph
235
238
full_bw_graph = joint_module .graph
@@ -373,8 +376,6 @@ def get_node_weight(node):
373
376
# To make this stuff deterministic
374
377
node_idx = {node : idx for idx , node in enumerate (joint_module .graph .nodes )}
375
378
saved_values = sorted ((name_to_node [node ] for node in cut_nodes ), key = lambda x : node_idx [x ])
376
- saved_values = [name_to_node [node ] for node in cut_nodes ]
377
-
378
379
return _extract_fwd_bwd_modules (joint_module , saved_values )
379
380
380
381
Original file line number Diff line number Diff line change @@ -314,7 +314,6 @@ class TestEagerFusionOpInfo(TestCase):
314
314
xfail ('linalg.cholesky' ),
315
315
skip ('msort' ),
316
316
xfail ('nn.functional.dropout' ),
317
- xfail ('polar' ),
318
317
xfail ('to_sparse' ),
319
318
xfail ('addcdiv' ),
320
319
xfail ('cholesky' ),
@@ -326,6 +325,8 @@ class TestEagerFusionOpInfo(TestCase):
326
325
xfail ('matrix_exp' ),
327
326
xfail ('trapezoid' ),
328
327
xfail ('trapz' ),
328
+ xfail ('corrcoef' ),
329
+ xfail ('cov' ),
329
330
skip ('nn.functional.binary_cross_entropy_with_logits' ), # seems to fail sometimes?
330
331
skip ('nn.functional.margin_ranking_loss' ), # seems flaky
331
332
})
You can’t perform that action at this time.
0 commit comments