-
Notifications
You must be signed in to change notification settings - Fork 370
feat: Autocast #3878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zewenli98
wants to merge
10
commits into
main
Choose a base branch
from
autocast
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
feat: Autocast #3878
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
eac8809
implement autocast
zewenli98 f6c7c7c
fix bug
zewenli98 f7d8068
add arg enable_autocast
zewenli98 e15ce94
change names of API and support for user specified node names
zewenli98 94757d2
support dataloader for calibration
zewenli98 4bf12e7
fix comments
zewenli98 0a62149
optimize Cast insertion logic, fix io dtype issue and comments, and a…
zewenli98 a990653
fix bugs in cpp runtime
zewenli98 3e008c2
amend doc and settings
zewenli98 a7a8039
make explicit typing as default, polish examples and docs
zewenli98 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| """ | ||
| .. _autocast_example: | ||
|
|
||
| An example of using Torch-TensorRT Autocast | ||
| ================ | ||
|
|
||
| This example demonstrates how to use Torch-TensorRT Autocast with PyTorch Autocast to compile a mixed precision model. | ||
| """ | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch_tensorrt | ||
|
|
||
| # %% Mixed Precision Model | ||
| # | ||
| # We define a mixed precision model that consists of a few layers, a ``log`` operation, and an ``abs`` operation. | ||
| # Among them, the ``fc1``, ``log``, and ``abs`` operations are within PyTorch Autocast context with ``dtype=torch.float16``. | ||
|
|
||
|
|
||
class MixedPytorchAutocastModel(nn.Module):
    """A small CNN whose last layers run under PyTorch Autocast (FP16).

    The convolution/pooling front end runs in full precision, while
    ``fc1``, ``abs``, and ``log`` execute inside a ``torch.autocast``
    region with ``dtype=torch.float16``. ``forward`` returns the input
    plus every intermediate activation so the example can inspect the
    per-layer output dtypes after compilation.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 8 * 8, 10)

    def forward(self, x):
        # Keep every intermediate activation; the example checks their
        # dtypes after compilation.
        acts = [x]
        for layer in (
            self.conv1,
            self.relu1,
            self.pool1,
            self.conv2,
            self.relu2,
            self.pool2,
            self.flatten,
        ):
            acts.append(layer(acts[-1]))
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
            acts.append(self.fc1(acts[-1]))
            # log stays fp32 due to PyTorch Autocast requirements
            acts.append(torch.log(torch.abs(acts[-1]) + 1))
        return tuple(acts)
|
|
||
|
|
||
| # %% | ||
| # Define the model, inputs, and calibration dataloader for Autocast, and then we run the original PyTorch model to get the reference outputs. | ||
|
|
||
# Instantiate the model in eval mode on the GPU and export it with torch.export.
model = MixedPytorchAutocastModel().cuda().eval()
inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),)
ep = torch.export.export(model, inputs)

# Wrap the example inputs in a small dataloader; Autocast uses it for calibration.
calibration_dataset = torch.utils.data.TensorDataset(*inputs)
calibration_dataloader = torch.utils.data.DataLoader(
    calibration_dataset, batch_size=2, shuffle=False
)

# Reference outputs from eager PyTorch, compared against the compiled module later.
pytorch_outs = model(*inputs)
|
|
||
# %% Compile the model with Torch-TensorRT Autocast
#
# We compile the model with Torch-TensorRT Autocast by setting ``enable_autocast=True``, ``use_explicit_typing=True``, and
# ``autocast_low_precision_type=torch.bfloat16``. To illustrate, we exclude the ``conv1`` node, all nodes with name
# containing ``relu``, and ``torch.ops.aten.flatten.using_ints`` ATen op from Autocast. In addition, we also set
# ``autocast_max_output_threshold``, ``autocast_max_depth_of_reduction``, and ``autocast_calibration_dataloader``. Please refer to
# the documentation for more details.

# Collect the Autocast-related settings in one place so they are easy to tweak.
autocast_settings = {
    "use_explicit_typing": True,
    "enable_autocast": True,
    "autocast_low_precision_type": torch.bfloat16,
    "autocast_excluded_nodes": {"^conv1$", "relu"},
    "autocast_excluded_ops": {"torch.ops.aten.flatten.using_ints"},
    "autocast_max_output_threshold": 512,
    "autocast_max_depth_of_reduction": None,
    "autocast_calibration_dataloader": calibration_dataloader,
}

trt_autocast_mod = torch_tensorrt.compile(
    ep.module(),
    arg_inputs=inputs,
    min_block_size=1,
    use_python_runtime=True,
    **autocast_settings,
)

autocast_outs = trt_autocast_mod(*inputs)
|
|
||
# %% Verify the outputs
#
# We verify both the dtype and values of the outputs of the model are correct.
# As expected, ``fc1`` is in FP16 because of PyTorch Autocast;
# ``pool1``, ``conv2``, and ``pool2`` are in BF16 because of Torch-TensorRT Autocast;
# the rest remain in FP32. Note that ``log`` is in FP32 because of PyTorch Autocast requirements.

# Group the outputs by the dtype each one is expected to have.
should_be_fp32 = [autocast_outs[i] for i in (0, 1, 2, 5, 7, 9)]
should_be_fp16 = [autocast_outs[i] for i in (8,)]
should_be_bf16 = [autocast_outs[i] for i in (3, 4, 6)]

assert all(
    a.dtype == torch.float32 for a in should_be_fp32
), "Some Autocast outputs are not float32!"
assert all(
    a.dtype == torch.float16 for a in should_be_fp16
), "Some Autocast outputs are not float16!"
assert all(
    a.dtype == torch.bfloat16 for a in should_be_bf16
), "Some Autocast outputs are not bfloat16!"

# Values should agree with eager PyTorch within a loose mixed-precision tolerance.
for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)):
    assert torch.allclose(
        a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2
    ), f"Autocast and Pytorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}"
print("All dtypes and values match!")
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.