Skip to content

Commit b503d47

Browse files
jfix71 authored and Wei Wei committed
[acc_shape_prop] Introduce and use for acc_tracer to support fp16 sample inputs (#73)
Summary: Pull Request resolved: pytorch/fx2trt#73 Tries to support shape prop for fp16 ops that don’t have pytorch CPU support. Does so by first attempting to use standard shape_prop, and if it fails then upconverts fp16 inputs to fp32 to re-run. This should make things much cleaner for acc_tracer, as the user can provide fp16 sample inputs directly instead of fp32 and then hacking things after the fact. Reviewed By: alexbeloi Differential Revision: D36305442 fbshipit-source-id: 2ecdc88a072d914cb26785d29fd7e409c51954fb
1 parent e610087 commit b503d47

File tree

5 files changed

+202
-16
lines changed

5 files changed

+202
-16
lines changed

test/tracer/test_acc_shape_prop.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Owner(s): ["oncall: fx"]
2+
3+
import operator
4+
import unittest
5+
6+
import fx2trt_oss.tracer.acc_tracer.acc_shape_prop as acc_shape_prop
7+
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
8+
import torch
9+
from parameterized import parameterized, param
10+
11+
# Seed torch's global RNG once at import time so the random parameters and
# sample inputs created by the tests below are the same on every run.
torch.manual_seed(0)
12+
13+
14+
class AccShapePropTest(unittest.TestCase):
    """Checks the dtypes that AccShapeProp records in ``node.meta["tensor_meta"]``.

    Every test traces a small module with ``acc_tracer.rewriter_base_trace``,
    runs ``AccShapeProp.propagate`` on sample inputs, and then inspects the
    metadata attached to the graph's nodes.
    """

    @parameterized.expand(
        [
            param(label, dtype=dt)
            for label, dt in (("fp32", torch.float32), ("fp16", torch.float16))
        ]
    )
    def test_basic(self, _, dtype):
        """All nodes of a single-dtype module carry that dtype."""

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(3, 4))
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))

        module = TestModule()
        if dtype == torch.float16:
            # Convert parameters in place so they match the fp16 sample input.
            module.half()
        traced = acc_tracer.rewriter_base_trace(module, None, None)
        sample = torch.rand(3, 4, dtype=dtype)
        acc_shape_prop.AccShapeProp(traced).propagate(sample)

        for node in traced.graph.nodes:
            recorded = node.meta["tensor_meta"]
            self.assertEqual(recorded.dtype, dtype)

    def test_mutli_dtype(self):
        """fp32 and fp16 inputs flow through independent branches unchanged."""

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return torch.relu(x * 2), torch.sigmoid(y + y)

        module = TestModule()
        traced = acc_tracer.rewriter_base_trace(module, None, None)
        # Note: One input is fp32, the other fp16.
        x, y = torch.rand(3, 4), torch.rand(3, 4, dtype=torch.float16)
        acc_shape_prop.AccShapeProp(traced).propagate(x, y)

        fp32_targets = {operator.mul, torch.relu}
        for node in traced.graph.nodes:
            recorded = node.meta["tensor_meta"]

            # The output node holds one metadata entry per returned tensor.
            if node.op == "output":
                self.assertEqual(recorded[0].dtype, torch.float32)
                self.assertEqual(recorded[1].dtype, torch.float16)
                continue

            # Only the "x" placeholder and the ops fed by it are fp32; every
            # other non-output node belongs to the fp16 "y" branch.
            is_x_placeholder = node.op == "placeholder" and node.target == "x"
            is_fp32_op = node.op == "call_function" and node.target in fp32_targets
            if is_x_placeholder or is_fp32_op:
                self.assertEqual(recorded.dtype, torch.float32)
            else:
                self.assertEqual(recorded.dtype, torch.float16)

    def test_to_dtype(self):
        """Explicit ``.to(dtype=...)`` conversions are reflected node by node."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x.to(dtype=torch.float32).to(dtype=torch.float16)

        module = TestModule()
        traced = acc_tracer.rewriter_base_trace(module, None, None)
        x = torch.rand(3, 4, dtype=torch.float16)
        acc_shape_prop.AccShapeProp(traced).propagate(x)

        placeholder = None
        for node in traced.graph.nodes:
            if node.op == "placeholder":
                placeholder = node
                expected = torch.float16
            elif node.all_input_nodes == [placeholder]:
                # The node consuming the placeholder directly is the fp32 cast.
                expected = torch.float32
            else:
                expected = torch.float16
            self.assertEqual(node.meta["tensor_meta"].dtype, expected)

    def test_split(self):
        """Nodes producing a pair of tensors record fp16 for both entries."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                s = torch.tensor_split(x, 2)
                return s[0].relu(), s[1].sigmoid()

        module = TestModule()
        traced = acc_tracer.rewriter_base_trace(module, None, None)
        x = torch.rand(2, 4, dtype=torch.float16)
        acc_shape_prop.AccShapeProp(traced).propagate(x)

        for node in traced.graph.nodes:
            recorded = node.meta["tensor_meta"]
            yields_pair = node.target == torch.tensor_split or node.op == "output"
            if yields_pair:
                self.assertEqual(recorded[0].dtype, torch.float16)
                self.assertEqual(recorded[1].dtype, torch.float16)
            else:
                self.assertEqual(recorded.dtype, torch.float16)

test/tracer/test_acc_tracer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,15 +1947,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
19471947

19481948
torch.testing.assert_allclose(m(input), traced(input))
19491949

1950-
def test_addmm(self):
1950+
@parameterized.expand([(torch.float,), (torch.float16,)])
1951+
def test_addmm(self, dtype):
19511952
class TestModule(torch.nn.Module):
19521953
def forward(
19531954
self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor
19541955
) -> torch.Tensor:
19551956
return torch.addmm(input, a, b)
19561957

19571958
m = TestModule()
1958-
input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)
1959+
input, a, b = (
1960+
torch.randn(2, 2, dtype=dtype),
1961+
torch.randn(2, 2, dtype=dtype),
1962+
torch.randn(2, 2, dtype=dtype),
1963+
)
19591964
traced = acc_tracer.trace(m, [input, a, b])
19601965

19611966
ph_in = ph_a = ph_b = mm = add = None
@@ -1983,7 +1988,11 @@ def forward(
19831988
else:
19841989
self.fail(f"Unexpected node: {node.format_node()}")
19851990

1986-
self.assertTrue(torch.equal(m(input, a, b), traced(input, a, b)))
1991+
for node in [ph_in, ph_a, ph_b, mm, add]:
1992+
self.assertEqual(acc_utils.get_tensor_meta(node).dtype, dtype)
1993+
1994+
if dtype == torch.float:
1995+
self.assertTrue(torch.equal(m(input, a, b), traced(input, a, b)))
19871996

19881997
def test_gelu(self):
    # Delegates to the shared _make_acc_op_function_test helper, passing the
    # acc op and the torch function it is paired with, and returns its result.
    # NOTE(review): presumably the helper checks that tracing maps
    # torch.nn.functional.gelu onto acc_ops.gelu -- confirm against the helper.
    return self._make_acc_op_function_test(acc_ops.gelu, torch.nn.functional.gelu)

tracer/acc_tracer/acc_shape_prop.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import sys
3+
from typing import Any
4+
5+
import torch.fx
6+
from torch.fx.passes import shape_prop
7+
8+
9+
class SuppressStderrPrints:
    """Context manager that discards everything written to ``sys.stderr``.

    On entry ``sys.stderr`` is replaced by a writable handle on ``os.devnull``.
    On exit the stream that was installed at entry is put back and the devnull
    handle is closed. Exceptions raised inside the ``with`` body are not
    swallowed.
    """

    def __enter__(self):
        self._original_stderr = sys.stderr
        # Keep our own reference to the sink so that __exit__ closes exactly
        # the handle opened here, even if sys.stderr is rebound in the body.
        self._devnull = open(os.devnull, "w")
        sys.stderr = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stderr is usable again even if close() were to fail.
        sys.stderr = self._original_stderr
        self._devnull.close()
17+
18+
19+
class AccShapeProp(shape_prop.ShapeProp):
    """
    Similar to standard shape prop, but if any node that is run with standard shape prop
    fails then it tries to upconvert any fp16 inputs to fp32, rerun shape prop, and then
    downconvert fp32 results back to fp16.

    Note that we currently mostly only look for/support up/down conversion for nodes
    with tensor outputs, but this is likely fine for most cases. Additionally the base
    shape_prop works for many ops with fp16, such as tensor.cat, tensor slice, tensor.to
    dtype conversion, etc.
    """

    def run_node(self, n: torch.fx.Node) -> Any:
        """Run ``n`` through the base shape prop, retrying in fp32 on failure.

        Args:
            n: The node to execute and annotate.

        Returns:
            The node's result. A tensor result that came out as fp32 from a
            retry with upconverted inputs is returned (and recorded in
            ``self.env`` and ``n.meta["tensor_meta"]``) as fp16.

        Raises:
            Exception: whatever the base ``run_node`` raises if the retry
                fails as well.
        """
        # First try running shape_prop with the original inputs. The failure is
        # deliberately swallowed (and stderr silenced for this attempt) because
        # the fp32 retry below is the real error path.
        with SuppressStderrPrints():
            try:
                return super().run_node(n)
            except Exception:
                pass

        # Base shape_prop failed, so temporarily upconvert the node's fp16 inputs in env
        # and retry. For now just support upconverting Tensor outputs.
        upconverted = []
        for in_node in n.all_input_nodes:
            in_ten = self.env[in_node]
            if isinstance(in_ten, torch.Tensor) and in_ten.dtype == torch.float16:
                upconverted.append((in_node, in_ten))
                # .to() with a different dtype already returns a new tensor, so
                # the original fp16 tensor is left untouched.
                self.env[in_node] = in_ten.to(dtype=torch.float)

        try:
            # Now try running again with upconverted fp32 input tensors in env.
            result = super().run_node(n)
        finally:
            # Always restore the original fp16 tensors, including when the
            # retry raises, so env never keeps the temporary fp32 copies.
            for in_node, in_ten in upconverted:
                self.env[in_node] = in_ten

        # The retry succeeded. Only if something was actually upconverted do we
        # attribute the success to that and downconvert an fp32 tensor result
        # back to fp16; otherwise an fp32 result is genuine and left alone.
        if (
            upconverted
            and isinstance(result, torch.Tensor)
            and result.dtype == torch.float
        ):
            result = result.to(dtype=torch.float16)
            self.env[n] = result
            n.meta["tensor_meta"] = n.meta["tensor_meta"]._replace(dtype=torch.float16)

        return result

tracer/acc_tracer/acc_tracer.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import fx2trt_oss.tracer.acc_tracer.acc_normalizer as acc_normalizer
1212
import fx2trt_oss.tracer.acc_tracer.acc_ops # noqa: F401
13+
import fx2trt_oss.tracer.acc_tracer.acc_shape_prop as acc_shape_prop
1314
import fx2trt_oss.tracer.acc_tracer.acc_utils as acc_utils
1415
import torch
1516
import torch.jit as jit
@@ -384,6 +385,19 @@ def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
384385
del node.meta["tensor_meta"]
385386

386387

388+
def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list):
    """Trace ``mod`` with ``AccRewritingTracer`` and wrap the result in a GraphModule.

    Args:
        mod: The module to rewrite and trace.
        ast_rewriter_allow_list: Forwarded unchanged to ``AccRewritingTracer.trace``.
        leaf_module_list: Forwarded unchanged to ``AccRewritingTracer.trace``.

    Returns:
        A ``torch.fx.GraphModule`` built from the traced graph, rooted at the
        rewritten module returned by the tracer.
    """
    tracer = AccRewritingTracer()
    graph, root = tracer.trace(
        mod,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )
    assert isinstance(root, nn.Module)

    # The rewritten module (not ``mod``) has to be the root, because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    return torch.fx.GraphModule(root, graph)
399+
400+
387401
def trace(
388402
mod: nn.Module,
389403
sample_inputs: Sequence[Any],
@@ -443,18 +457,10 @@ def trace(
443457
)
444458
mod.eval()
445459

446-
# Rewrite the module to make it symbolic traceable, and then trace it.
447-
rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
448-
mod,
449-
ast_rewriter_allow_list=ast_rewriter_allow_list,
450-
leaf_module_list=leaf_module_list,
451-
)
452-
453-
assert isinstance(rewritten_mod, nn.Module)
454460
assert isinstance(sample_inputs, (list, tuple))
455-
# Note: use the rewritten_mod here as the root. This is necessary because
456-
# RewrittenModule includes a new module for the ConditionalExceptionWrapper.
457-
traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)
461+
462+
# Rewrite the module to make it symbolic traceable, and then trace it.
463+
traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
458464

459465
# Now remove all assertions and exceptions if requested.
460466
if remove_assertions:
@@ -467,7 +473,7 @@ def trace(
467473
traced.graph.eliminate_dead_code()
468474

469475
# Run shape prop to add node.meta["type"] to nodes, needed for NormalizeArgs.
470-
shape_prop.ShapeProp(traced).propagate(*sample_inputs)
476+
acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
471477
# Swap out tensor_meta for tensor_rank, because we don't actually want to rely on
472478
# tensor_meta yet for normalization/lowering, though rank shouldn't change.
473479
_replace_tensor_meta_with_rank(traced)
@@ -483,6 +489,6 @@ def trace(
483489
traced.recompile()
484490

485491
# Run shape prop again to populate tensor_meta after normalize.
486-
shape_prop.ShapeProp(traced).propagate(*sample_inputs)
492+
acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
487493

488494
return traced

tracer/acc_tracer/acc_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,14 @@ def map_tensor_metadata(a: Any, fn: Callable):
189189
a, list
190190
), f"Only supporting tuple/list/TensorMetadata, but found {type(a)}"
191191
return immutable_list(map_tensor_metadata(elem, fn) for elem in a)
192+
193+
194+
def get_tensor_meta(node: torch.fx.Node) -> TensorMetadata:
    """Return the ``"tensor_meta"`` entry stored in ``node.meta``.

    Args:
        node: The node whose metadata is looked up.

    Returns:
        The value stored under ``node.meta["tensor_meta"]``.

    Raises:
        RuntimeError: If the entry is missing or falsy, which indicates that
            shape propagation has not annotated this node.
    """
    meta = node.meta.get("tensor_meta")
    if meta:
        return meta

    raise RuntimeError(
        f"Node has no tensor metadata associated with it! "
        f"Check that shape propagation has run. {node.format_node()}"
    )

0 commit comments

Comments
 (0)