 import logging

 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )

 logger = logging.getLogger(__name__)


-def accumulate_fp32_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def accumulate_fp32_matmul(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Replace a matmul layer with fp32 accumulation nodes"""
-    matmul_targets = [
-        torch.ops.aten.mm.default,
-        torch.ops.aten.bmm.default,
-        torch.ops.aten.addmm.default,
-    ]
-    matmul_nodes = [node for node in gm.graph.nodes if node.target in matmul_targets]
-    for matmul_node in matmul_nodes:
-        # Prior to the matmul node, insert a cast to the 32-bit float32 node
-        node_inputs = matmul_node.all_input_nodes
-
-        for node_input in node_inputs:
-            with gm.graph.inserting_before(matmul_node):
-                node_32bit = gm.graph.call_function(
+    if settings.use_fp32_acc:
+        matmul_targets = [
+            torch.ops.aten.mm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.addmm.default,
+        ]
+
+        matmul_nodes = [
+            node for node in gm.graph.nodes if node.target in matmul_targets
+        ]
+        for matmul_node in matmul_nodes:
+            # Prior to the matmul node, insert a cast to the 32-bit float32 node
+            node_inputs = matmul_node.all_input_nodes
+
+            for node_input in node_inputs:
+                with gm.graph.inserting_before(matmul_node):
+                    node_32bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(node_input,),
+                        kwargs={"dtype": torch.float32},
+                    )
+
+                    # Replace the input to matmul node with new 32-bit cast node
+                    matmul_node.replace_input_with(node_input, node_32bit)
+
+            # Add a cast back to original precision
+            with gm.graph.inserting_after(matmul_node):
+                node_orig_precision = gm.graph.call_function(
                     torch.ops.aten._to_copy.default,
-                    args=(node_input,),
-                    kwargs={"dtype": torch.float32},
+                    args=(matmul_node,),
+                    kwargs={"dtype": torch.float16},
                 )
+                matmul_node.replace_all_uses_with(
+                    node_orig_precision, propagate_meta=False
+                )
+                # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
+                node_orig_precision.replace_input_with(
+                    node_orig_precision.all_input_nodes[0], matmul_node
+                )
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after enabling matmul layers to use FP32 accumulation:\n{gm.graph}"
+        )
+    else:
+        logger.debug(
+            "Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings"
+        )

-                # Replace the input to matmul node with new 32-bit cast node
-                matmul_node.replace_input_with(node_input, node_32bit)
-
-        # Add a cast back to original precision
-        with gm.graph.inserting_after(matmul_node):
-            node_orig_precision = gm.graph.call_function(
-                torch.ops.aten._to_copy.default,
-                args=(matmul_node,),
-                kwargs={"dtype": torch.float16},
-            )
-            matmul_node.replace_all_uses_with(node_orig_precision, propagate_meta=False)
-            # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
-            node_orig_precision.replace_input_with(
-                node_orig_precision.all_input_nodes[0], matmul_node
-            )
-
-    gm = clean_up_graph_after_modifications(gm)
-    logger.debug(f"Graph after changing matmuls to use FP32 accumulation:\n{gm.graph}")
     return gm
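
For context, a minimal sketch of how the updated pass could be exercised on its own. The import path for accumulate_fp32_matmul is assumed from the file touched by this diff, and CompilationSettings is taken from the import the diff adds; the toy FX graph is built by hand so that it contains an aten.mm.default node for the pass to rewrite.

import torch
from torch.fx import Graph, GraphModule

from torch_tensorrt.dynamo._settings import CompilationSettings
# Assumed import path, inferred from the file modified in this diff:
from torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul import (
    accumulate_fp32_matmul,
)

# Build a tiny FX graph containing a single aten.mm node.
graph = Graph()
a = graph.placeholder("a")
b = graph.placeholder("b")
mm = graph.call_function(torch.ops.aten.mm.default, args=(a, b))
graph.output(mm)
gm = GraphModule(torch.nn.Module(), graph)

# With use_fp32_acc enabled, the pass casts both matmul inputs to FP32 and
# casts the result back to FP16; with it disabled, the graph is left untouched.
settings = CompilationSettings(use_fp32_acc=True)
gm = accumulate_fp32_matmul(gm, settings)
print(gm.graph)

Note that the cast back after the matmul is hard-coded to torch.float16 in this revision, so despite the "original precision" comment the pass as written assumes FP16 inputs rather than restoring an arbitrary original dtype.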