
Commit 6e0da4d (1 parent: 78e67cc)

chore: update doc for partitioning

Signed-off-by: Bo Wang <[email protected]>

core/partitioning/README.md: 7 additions & 5 deletions
@@ -1,6 +1,6 @@
 # TRTorch Partitioning
 
-TRTorch partitioning phase is developed to support automatic fallback feature in TRTorch. This phase won't run by
+TRTorch partitioning phase is developed to support `automatic fallback` feature in TRTorch. This phase won't run by
 default until the automatic fallback feature is enabled.
 
 On a high level, TRTorch partitioning phase does the following:
@@ -15,6 +15,8 @@ from the user. Shapes can be calculated by running the graphs with JIT.
 it's still a phase in our partitioning process.
 - `Stitching`. Stitch all TensorRT engines with PyTorch nodes altogether.
 
+Test cases for each of these components could be found [here](https://github.com/NVIDIA/TRTorch/tree/master/tests/core/partitioning).
+
 Here is the brief description of functionalities of each file:
 - `PartitionInfo.h/cpp`: The automatic fallback APIs that is used for partitioning.
 - `SegmentedBlock.h/cpp`: The main data structures that is used to maintain information for each segments after segmentation.
@@ -34,8 +36,8 @@ To enable automatic fallback feature, you can set following attributes in Python
 ...
 "torch_fallback" : {
   "enabled" : True,
-  "min_block_size" : 1,
-  "forced_fallback_ops": ["aten::foo"],
+  "min_block_size" : 3,
+  "forced_fallback_ops": ["aten::add"],
 }
 })
 ```
@@ -58,8 +60,8 @@ auto mod = trtorch::jit::load("trt_ts_module.ts");
 auto input_sizes = std::vector<trtorch::CompileSpec::InputRange>{{in.sizes()}};
 trtorch::CompileSpec cfg(input_sizes);
 cfg.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
-cfg.torch_fallback.min_block_size = 1;
-cfg.torch_fallback.forced_fallback_ops.push_back("aten::foo");
+cfg.torch_fallback.min_block_size = 2;
+cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");
 auto trt_mod = trtorch::CompileGraph(mod, cfg);
 auto out = trt_mod.forward({in});
 ```
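For context, here is the updated C++ fallback configuration assembled into a self-contained program. This is a minimal sketch, not part of the commit: the module path and input shape are illustrative, the `main` scaffolding and the `mod.to`/`mod.eval` calls are added here, and it assumes the standard TorchScript loader `torch::jit::load` (the README snippet writes `trtorch::jit::load`); the `CompileSpec`, `TorchFallback`, and `CompileGraph` calls are taken directly from the diff above.

```c++
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Load a previously saved TorchScript module; path and shape are illustrative.
  auto mod = torch::jit::load("trt_ts_module.ts");
  mod.to(torch::kCUDA);
  mod.eval();
  auto in = torch::randn({1, 3, 224, 224}, torch::kCUDA);

  // Build the compile spec from the example input's shape.
  auto input_sizes = std::vector<trtorch::CompileSpec::InputRange>{{in.sizes()}};
  trtorch::CompileSpec cfg(input_sizes);

  // Enable automatic fallback: only TensorRT-convertible segments of at least
  // min_block_size ops become TensorRT engines, and any op listed in
  // forced_fallback_ops (here aten::relu, as in the diff) stays in PyTorch.
  cfg.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
  cfg.torch_fallback.min_block_size = 2;
  cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");

  // Compile: the partitioning phase segments the graph, converts the TensorRT
  // segments, and stitches them back together with the PyTorch nodes.
  auto trt_mod = trtorch::CompileGraph(mod, cfg);
  auto out = trt_mod.forward({in});
  return 0;
}
```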
