
Commit fe0a9aa

Merge pull request #97 from NVIDIA/ptq_reorganization
C++ API Refactor
2 parents d06a2b1 + 8445e81 commit fe0a9aa

File tree

96 files changed (+5291 additions, -7943 deletions)


cpp/api/include/trtorch/logging.h

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #include <string>

cpp/api/include/trtorch/macros.h

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #if defined(__GNUC__)

cpp/api/include/trtorch/ptq.h

Lines changed: 55 additions & 4 deletions
@@ -1,23 +1,30 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #include <string>
 #include <vector>
 #include <memory>
 #include <iostream>
+#include <fstream>
+#include <iterator>
 #include <sstream>
 
+#include "torch/torch.h"
 #include "trtorch/logging.h"
+#include "NvInfer.h"
 
 #ifndef DOXYGEN_SHOULD_SKIP_THIS
 namespace nvinfer1 {
 class IInt8Calibrator;
 class IInt8EntropyCalibrator2;
 }
 
-namespace torch {
-class Tensor;
-}
-
 namespace trtorch {
 namespace ptq {
 bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
@@ -269,5 +276,49 @@ class Int8CacheCalibrator : Algorithm {
   std::vector<char> cache_;
 };
 
+/**
+ * @brief A factory to build a post training quantization calibrator from a torch dataloader
+ *
+ * Creates a calibrator to use for post training quantization. By default the returned calibrator uses the TensorRT Entropy v2
+ * algorithm to perform calibration. This is recommended for feed forward networks. You can override the algorithm selection
+ * (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_calibrator with the calibrator class
+ * as a template parameter.
+ *
+ * e.g. ``trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);``
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ * @tparam DataLoader: std::unique_ptr<torch::data::DataLoader> - DataLoader type
+ * @param dataloader: std::unique_ptr<torch::data::DataLoader> - DataLoader containing data
+ * @param cache_file_path: const std::string& - Path to read/write the calibration cache
+ * @param use_cache: bool - whether to use the calibration cache
+ * @return Int8Calibrator<Algorithm, DataLoader>
+ */
+
+template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
+TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
+  return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
+}
+
+/**
+ * @brief A factory to build a post training quantization calibrator that only uses the calibration cache
+ *
+ * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache, so
+ * you can have a calibration-cache-generating program that requires a dataloader and a dataset, then save the cache to use later
+ * in a different program that does not need to calibrate from scratch and has no dataset dependency. However, the network should
+ * be recalibrated if its structure changes or the input data set changes, and it is the responsibility of the application to ensure this.
+ *
+ * By default the returned calibrator uses the TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks.
+ * You can override the algorithm selection (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_cache_calibrator with
+ * the calibrator class as a template parameter.
+ *
+ * e.g. ``trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);``
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ * @param cache_file_path: const std::string& - Path to read/write the calibration cache
+ * @return Int8CacheCalibrator<Algorithm>
+ */
+template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
+TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
+  return Int8CacheCalibrator<Algorithm>(cache_file_path);
+}
+
 } // namespace ptq
 } // namespace trtorch
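To make the intended use of these factories concrete, here is a minimal sketch of constructing calibrators with them. Only the two make_* calls come from the diff above; the cache file path, the helper function name, and the idea of assigning the result to a ptq_calibrator field on the compile settings (as the PTQ example application does) are illustrative assumptions, not part of this commit.

// Sketch only: how the factories declared above are meant to be called.
// Paths, the helper name, and the settings wiring are assumptions for illustration.
#include <string>

#include "torch/torch.h"
#include "trtorch/trtorch.h"
#include "trtorch/ptq.h"
#include "NvInfer.h"

void build_calibrators_sketch() {
  const std::string cache_file = "calibration.cache";  // hypothetical path

  // Cache-only calibrator: reads a previously generated calibration cache,
  // so this program needs no dataset or dataloader dependency.
  auto cache_calibrator = trtorch::ptq::make_int8_cache_calibrator(cache_file);

  // Dataloader-backed calibrator (dataloader construction elided for brevity):
  //   auto calibrator = trtorch::ptq::make_int8_calibrator(
  //       std::move(dataloader), cache_file, /*use_cache=*/true);
  // The default Algorithm is nvinfer1::IInt8EntropyCalibrator2; pass e.g.
  // nvinfer1::IInt8MinMaxCalibrator as the template argument to override it.

  // The resulting object is then attached to the compile settings before
  // trtorch::CompileGraph is invoked (assumed here to be a ptq_calibrator
  // field on trtorch::ExtraInfo, as in the cpp/ptq example application).
  (void)cache_calibrator;
}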

cpp/api/include/trtorch/trtorch.h

Lines changed: 1 addition & 49 deletions
@@ -29,13 +29,11 @@ class ArrayRef;
 }
 
 namespace nvinfer1 {
-class IInt8EntropyCalibrator2;
+class IInt8Calibrator;
 }
 #endif //DOXYGEN_SHOULD_SKIP_THIS
 
 #include "trtorch/macros.h"
-#include "trtorch/logging.h"
-#include "trtorch/ptq.h"
 namespace trtorch {
 /**
  * Settings data structure for TRTorch compilation
@@ -406,50 +404,4 @@ TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, Ex
  * @return: std::string: Serialized TensorRT engine equivilant to the method graph
  */
 TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
-
-namespace ptq {
-/**
- * @brief A factory to build a post training quantization calibrator from a torch dataloader
- *
- * Creates a calibrator to use for post training quantization. By default the returned calibrator uses TensorRT Entropy v2
- * algorithm to perform calibration. This is recommended for feed forward networks. You can override the algorithm selection
- * (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with the calibrator class
- * as a template parameter.
- *
- * e.g. ``trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);``
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
- * @tparam DataLoader: std::unique_ptr<torch::data::DataLoader> - DataLoader type
- * @param dataloader: std::unique_ptr<torch::data::DataLoader> - DataLoader containing data
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
- * @param use_cache: bool - use calibration cache
- * @return Int8Calibrator<Algorithm, DataLoader>
- */
-
-template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
-TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
-  return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
-}
-
-/**
- * @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache
- *
- * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache, therefore
- * you can have a calibration cache generating program that requires a dataloader and a dataset, then save the cache to use later
- * in a different program that needs to calibrate from scratch and not have the dataset dependency. However, the network should also
- * be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this.
- *
- * By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks
- * You can override the algorithm selection (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with
- * the calibrator class as a template parameter.
- *
- * e.g. trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
- * @return Int8CacheCalibrator<Algorithm>
- */
-template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
-TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
-  return Int8CacheCalibrator<Algorithm>(cache_file_path);
-}
-} // namespace ptq
 } // namespace trtorch

cpp/ptq/main.cpp

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 #include "torch/script.h"
 #include "torch/torch.h"
 #include "trtorch/trtorch.h"
+#include "trtorch/ptq.h"
 
 #include "NvInfer.h"
 
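This one-line change follows from the header refactor above: trtorch/trtorch.h no longer pulls in trtorch/ptq.h (or trtorch/logging.h), so applications that use the PTQ factories must now include the PTQ header themselves. The sketch below shows the resulting include set for a PTQ application; which headers beyond trtorch/ptq.h are actually needed depends on the application, so treat it as illustrative only.

// Illustrative include set for a PTQ application after this refactor (sketch).
#include "torch/script.h"     // loading the TorchScript module to be compiled
#include "torch/torch.h"      // datasets/dataloaders used for calibration
#include "trtorch/trtorch.h"  // compile API; no longer includes ptq.h or logging.h
#include "trtorch/ptq.h"      // make_int8_calibrator / make_int8_cache_calibrator
#include "NvInfer.h"          // needed to name a non-default calibrator algorithm

int main() {
  // Application code using the PTQ factories and trtorch::CompileGraph goes here.
  return 0;
}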

docs/._index.html

0 Bytes
Binary file not shown.

docs/_cpp_api/class_view_hierarchy.html

Lines changed: 67 additions & 81 deletions
@@ -215,27 +215,13 @@
 </a>
 <ul class="md-nav__list">
 <li class="md-nav__item">
-<a class="md-nav__link" href="../tutorials/installation.html#dependencies">
-Dependencies
+<a class="md-nav__link" href="../tutorials/installation.html#precompiled-binaries">
+Precompiled Binaries
 </a>
 </li>
 <li class="md-nav__item">
-<a class="md-nav__link" href="../tutorials/installation.html#dependencies-for-compilation">
-Dependencies for Compilation
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../tutorials/installation.html#building-using-cudnn-tensorrt-tarball-distributions">
-<strong>
-Building using cuDNN &amp; TensorRT tarball distributions
-</strong>
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../tutorials/installation.html#building-using-locally-installed-cudnn-tensorrt">
-<strong>
-Building using locally installed cuDNN &amp; TensorRT
-</strong>
+<a class="md-nav__link" href="../tutorials/installation.html#compiling-from-source">
+Compiling From Source
 </a>
 </li>
 </ul>
@@ -299,6 +285,69 @@
 trtorchc
 </a>
 </li>
+<li class="md-nav__item">
+<span class="md-nav__link caption">
+<span class="caption-text">
+Python API Documenation
+</span>
+</span>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="../py_api/trtorch.html">
+trtorch
+</a>
+<ul class="md-nav__list">
+<li class="md-nav__item">
+<a class="md-nav__link" href="../py_api/trtorch.html#functions">
+Functions
+</a>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="../py_api/trtorch.html#enums">
+Enums
+</a>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="../py_api/trtorch.html#submodules">
+Submodules
+</a>
+</li>
+</ul>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="../py_api/logging.html">
+trtorch.logging
+</a>
+</li>
+<li class="md-nav__item">
+<span class="md-nav__link caption">
+<span class="caption-text">
+C++ API Documenation
+</span>
+</span>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="trtorch_cpp.html">
+TRTorch C++ API
+</a>
+<ul class="md-nav__list">
+<li class="md-nav__item">
+<a class="md-nav__link" href="trtorch_cpp.html#class-hierarchy">
+Class Hierarchy
+</a>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="trtorch_cpp.html#file-hierarchy">
+File Hierarchy
+</a>
+</li>
+<li class="md-nav__item">
+<a class="md-nav__link" href="trtorch_cpp.html#full-api">
+Full API
+</a>
+</li>
+</ul>
+</li>
 <li class="md-nav__item">
 <span class="md-nav__link caption">
 <span class="caption-text">
@@ -397,69 +446,6 @@
 </li>
 </ul>
 </li>
-<li class="md-nav__item">
-<span class="md-nav__link caption">
-<span class="caption-text">
-Python API Documenation
-</span>
-</span>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../py_api/trtorch.html">
-trtorch
-</a>
-<ul class="md-nav__list">
-<li class="md-nav__item">
-<a class="md-nav__link" href="../py_api/trtorch.html#functions">
-Functions
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../py_api/trtorch.html#enums">
-Enums
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../py_api/trtorch.html#submodules">
-Submodules
-</a>
-</li>
-</ul>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="../py_api/logging.html">
-trtorch.logging
-</a>
-</li>
-<li class="md-nav__item">
-<span class="md-nav__link caption">
-<span class="caption-text">
-C++ API Documenation
-</span>
-</span>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="trtorch_cpp.html">
-TRTorch C++ API
-</a>
-<ul class="md-nav__list">
-<li class="md-nav__item">
-<a class="md-nav__link" href="trtorch_cpp.html#class-hierarchy">
-Class Hierarchy
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="trtorch_cpp.html#file-hierarchy">
-File Hierarchy
-</a>
-</li>
-<li class="md-nav__item">
-<a class="md-nav__link" href="trtorch_cpp.html#full-api">
-Full API
-</a>
-</li>
-</ul>
-</li>
 </ul>
 </nav>
 </div>
