Skip to content

Commit 017d857

Browse files
pytorchbotdolpmLucaskabela
authored
fix pickling for BitwiseFn (pytorch#163861)
fix pickling for BitwiseFn (pytorch#163861)
* fix pickling for BitwiseFn (pytorch#163571). Summary: ran into `AttributeError: Can't get local object 'make_opaque_bitwise_fn.<locals>.BitwiseFn'`. It looks like this was fixed for UnaryFn but not BitwiseFn in pytorch#138395. Fixes pytorch#147841. Pull Request resolved: pytorch#163571. Approved by: https://github.com/jamesjwu (cherry picked from commit cde5c9a)
* Fix lintrunner with -a
---------
Co-authored-by: dolpm <[email protected]>
Co-authored-by: Lucas Kabela <[email protected]>
1 parent d6e8411 commit 017d857

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

test/inductor/test_compile_subprocess.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def model_add(x, y):
206206

207207
start = time.time()
208208
last_report = start
209-
while _AsyncFxCompile._stat_compiled_runs < 4:
209+
while True:
210+
start_stat_compiled_runs = _AsyncFxCompile._stat_compiled_runs
210211
# Sleep a bit so we don't drive the CPU unnecessarily.
211212
time.sleep(0.25)
212213

@@ -219,6 +220,9 @@ def model_add(x, y):
219220
# Backward pass
220221
output.sum().backward()
221222

223+
if _AsyncFxCompile._stat_compiled_runs - start_stat_compiled_runs == 2:
224+
break
225+
222226
# DEBUGGING: Print a periodic message so we know we're still
223227
# running...
224228
now = time.time()
@@ -231,12 +235,12 @@ def model_add(x, y):
231235
"Test timed out before producing a compiled artifact."
232236
)
233237

234-
self.assertEqual(_AsyncFxCompile._stat_compiled_runs, 4)
238+
self.assertGreater(_AsyncFxCompile._stat_compiled_runs, 1)
235239
# Make sure we ran eager at least once. Normally this will be
236240
# something like 80.
237241
self.assertGreater(_AsyncFxCompile._stat_eager_runs, 0)
238-
self.assertEqual(_AsyncFxCompile._stat_bg_started, 1)
239-
self.assertEqual(_AsyncFxCompile._stat_bg_finished, 1)
242+
self.assertEqual(_AsyncFxCompile._stat_bg_started, 2)
243+
self.assertEqual(_AsyncFxCompile._stat_bg_finished, 2)
240244

241245

242246
if RUN_CPU:

test/test_sympy_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import pickle
77
import sys
8-
from typing import Callable
8+
from collections.abc import Callable
99

1010
import sympy
1111

@@ -24,6 +24,7 @@
2424
FloorDiv,
2525
Identity,
2626
OpaqueUnaryFn_cos,
27+
BitwiseFn_bitwise_and,
2728
simple_floordiv_gcd,
2829
)
2930
from torch.utils._sympy.interp import sympy_interp
@@ -426,7 +427,7 @@ def test_interp(self, fn):
426427
# Yes, I know this is a long-winded way of saying xreplace; the
427428
# point is to test sympy_interp
428429
r = sympy_interp(
429-
ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr
430+
ReferenceAnalysis, dict(zip(symbols, sargs, strict=False)), sympy_expr
430431
)
431432
self.assertEqual(ref_r, r)
432433

@@ -501,7 +502,7 @@ def trace_f(px, py):
501502

502503
self.assertEqual(
503504
sympy_interp(
504-
PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr
505+
PythonReferenceAnalysis, dict(zip(symbols, args, strict=False)), sympy_expr
505506
),
506507
gm(*args),
507508
)
@@ -555,7 +556,7 @@ def test_tensor_interp(self, fn):
555556
direct_result = tensor_fn(*tensor_args)
556557
interp_result = sympy_interp(
557558
TensorReferenceAnalysis,
558-
dict(zip(symbols, tensor_args)),
559+
dict(zip(symbols, tensor_args, strict=False)),
559560
sympy_expr,
560561
)
561562

@@ -873,6 +874,10 @@ def test_pickle(self):
873874
r = pickle.loads(pickle.dumps(x))
874875
self.assertEqual(x, r)
875876

877+
x = BitwiseFn_bitwise_and(sympy.Symbol("a"), sympy.Symbol("b"))
878+
r = pickle.loads(pickle.dumps(x))
879+
self.assertEqual(x, r)
880+
876881

877882
class TestSingletonInt(TestCase):
878883
def test_basic(self):

torch/utils/_sympy/functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import math
44
import operator
55
import sys
6-
from typing import Callable, Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
6+
from collections.abc import Callable
7+
from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
78
from typing_extensions import TypeVarTuple, Unpack
89

910
import sympy
@@ -1192,7 +1193,8 @@ def eval(cls, *args):
11921193
# When all strides are integral, we can sort, and the size for the
11931194
# largest stride doesn't matter and can be arbitrarily symbolic
11941195
s_sizes, s_strides = zip(
1195-
*sorted(zip(sizes, strides), key=operator.itemgetter(1))
1196+
*sorted(zip(sizes, strides, strict=False), key=operator.itemgetter(1)),
1197+
strict=False,
11961198
)
11971199
# Put something arbitrary in the max size spot, it'll be ignored
11981200
if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
@@ -1411,7 +1413,10 @@ def eval(cls, a, b):
14111413
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
14121414
return None
14131415

1414-
BitwiseFn.__name__ = "BitwiseFn_" + name
1416+
nm = "BitwiseFn_" + name
1417+
BitwiseFn.__name__ = nm
1418+
BitwiseFn.__qualname__ = nm
1419+
14151420
return BitwiseFn
14161421

14171422

0 commit comments

Comments (0)