
Commit 2d14e86

fxdawnn authored and pytorchmergebot committed
[HOP][print][dynamo] Add dynamo support for HOP print (pytorch#167571)

Following the previous implementation of HOP print, this change adds dynamo support for HOP print so that torch.compile works with the eager (fullgraph) and aot_eager backends. With this implementation, HOP print can perform stateful printing without causing a graph break. With the prior built-in print, dynamo could reduce graph breaks but could not eliminate them; this change enables format-based printing for that purpose in dynamo.

Pull Request resolved: pytorch#167571
Approved by: https://github.com/angelayi
ghstack dependencies: pytorch#167016
1 parent 8bb1152 commit 2d14e86
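
For context, a minimal usage sketch adapted from the new test test_reorder_print_no_graph_break in this commit (the function f and the sample input below are illustrative, not part of the PR):

import torch

def f(x):
    x1 = x + x
    # Format-based HOP print; traced as a higher-order op instead of breaking the graph
    torch._higher_order_ops.print("moo {x}", x=x1)
    return x1 * x1

# With this change, the eager and aot_eager backends compile f as a single graph
# even though it prints, so fullgraph=True succeeds.
opt_f = torch.compile(f, backend="eager", fullgraph=True)
out = opt_f(torch.randn(3, 3))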

File tree: 2 files changed (+88, −4 lines)


test/higher_order_ops/test_print.py

Lines changed: 63 additions & 4 deletions
@@ -3,12 +3,17 @@
 from unittest.mock import patch
 
 import torch
-from torch._dynamo.utils import counters
 from torch._functorch.aot_autograd import aot_export_module
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+    TestCase,
+)
 
 
+@instantiate_parametrized_tests
 class TestHopPrint(TestCase):
     def test_base_print(self):
         def f(x):
@@ -18,7 +23,6 @@ def f(x):
             torch._higher_order_ops.print("moo")
             return x
 
-        counters.clear()
         x = torch.randn(3, 3)
         with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
             f(x)
@@ -33,7 +37,6 @@ def f(x):
             x = x * x
             return x
 
-        counters.clear()
         x = torch.randn(3, 3)
         with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
             f(x)
@@ -184,6 +187,62 @@ def test_print_gen_schema(self):
             """print(str format_str) -> ()""",
         )
 
+    @parametrize("backend", ["eager", "aot_eager"])
+    def test_reorder_print_no_graph_break(self, backend):
+        def f(x):
+            x1 = x + x
+            torch._higher_order_ops.print("moo {x}", x=x1)
+            x2 = x1 * x1
+            torch._higher_order_ops.print("moo {x}", x=x2)
+            x3 = x2 + x2
+            return (x1, x3)
+
+        # Eager and aot_eager backend for dynamo tracing testing
+        x = torch.randn(3, 3)
+        opt_f = torch.compile(backend=backend, fullgraph=True)(f)
+        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
+            opt_out = opt_f(x)
+            printed_output = mock_stdout.getvalue().strip()
+            orig_out = f(x)
+
+        self.assertEqual(
+            printed_output,
+            f"moo {x * 2}\nmoo {x * 2 * x * 2}",
+        )
+        self.assertEqual(orig_out, opt_out)
+
+        x_new = torch.randn(2, 2)
+        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
+            opt_out = opt_f(x_new)
+            printed_output = mock_stdout.getvalue().strip()
+
+        self.assertEqual(
+            printed_output,
+            f"moo {x_new * 2}\nmoo {x_new * 2 * x_new * 2}",
+        )
+
+    @parametrize("backend", ["eager", "aot_eager"])
+    def test_constant_mutation(self, backend):
+        def f(x):
+            alist = [x]
+            alist.append(x + 1)
+            torch._higher_order_ops.print("moo {x}", x=alist[-1])
+            alist[0].sum().item()  # graph break
+            res = alist.pop()
+            torch._higher_order_ops.print("moo {x}", x=alist[-1])
+            res.sum().item()  # graph break
+            return res
+
+        inputs = (torch.tensor([1]),)
+        opt_f = torch.compile(backend=backend, fullgraph=True)(f)
+        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
+            opt_out = opt_f(*inputs)
+            printed_output = mock_stdout.getvalue().strip()
+            orig_out = f(*inputs)
+
+        self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])")
+        self.assertEqual(orig_out, opt_out)
+
 
 if __name__ == "__main__":
     run_tests()

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 25 additions & 0 deletions
@@ -2673,6 +2673,30 @@ def _call_function(
         )
 
 
+class PrintHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def _call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        from .builder import wrap_fx_proxy
+
+        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
+
+        args_proxy = [arg.as_proxy() for arg in args]
+        kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
+        return wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                self.value,
+                args=tuple(args_proxy),
+                kwargs=kwargs_proxy,
+            ),
+        )
+
+
 class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
     def _call_function(
         self,
@@ -4537,6 +4561,7 @@ def make_error_msg(*args):
     "associative_scan": AssociativeScanHigherOrderVariable,
     "scan": ScanHigherOrderVariable,
     "call_torchbind": CallTorchbindHigherOrderVariable,
+    "print": PrintHigherOrderVariable,
     "wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable,
     "wrap_with_autocast": WrapWithAutocastHigherOrderVariable,
     "dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable,

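To see the effect of registering PrintHigherOrderVariable, a quick sketch using a custom dynamo backend that dumps the captured FX graph (the inspect_backend helper below is illustrative and not part of this PR; the expectation that the print appears as a node in a single captured graph is an assumption based on the tests above):

import torch

def inspect_backend(gm, example_inputs):
    # Illustrative custom backend: print the captured FX graph and run it unchanged.
    gm.print_readable()
    return gm.forward

def f(x):
    torch._higher_order_ops.print("moo {x}", x=x)
    return x + 1

# fullgraph=True: the HOP print should be traced into the graph rather than
# triggering a graph break.
opt_f = torch.compile(f, backend=inspect_backend, fullgraph=True)
opt_f(torch.ones(2))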