
Commit 3d84b43

fix: Device casting issues with certain aten operators (#1416)
* fix: Device casting issues with certain `aten` operators

  - Investigated an issue arising with the BART-base model (https://huggingface.co/facebook/bart-base) where certain tensor inputs to TensorRT were on the CPU, despite users explicitly casting all inputs properly
  - Traced the issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors passed between Torch and Torch-TensorRT engines
  - Added lowering passes to ensure function edge cases are handled appropriately and tensors are located on the proper device at runtime; added a validation check in the runtime to avoid models crashing due to device mismatches
  - Added testing for the lowering passes to ensure output values are accurate

* fix: Update paradigm for device casting to depend on the user-specified device

  - Added a field to LowerInfo to hold device information
  - Updated the internal Device struct location (now in core/ir) to allow streamlined imports
  - Updated BUILD files
  - Built device strings in the lowering phase using the user-specified target device
  - Updated CMakeLists to reflect the IR dependency in lowering
  - Updated the runtime device-relocation code to run regardless of whether a device switch is required
1 parent 86ff042 commit 3d84b43
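To make the failure mode concrete, here is a hedged libtorch sketch of the class of mismatch described above; it is illustrative only and not taken from the commit or the BART model. A tensor created inside the graph (a mask, a 0D scalar, an aten::full output) defaults to the CPU even though the user's inputs live on CUDA, so mixed-device ops such as masked_fill_ can fail at runtime.

#include <torch/torch.h>

// Illustrative only: mirrors the issue this commit addresses, where tensors
// created internally (masks, 0D scalars, aten::full outputs) default to the
// CPU while user inputs are on CUDA.
void sketch() {
  auto self = torch::zeros({2, 3}, torch::kCUDA);  // user input, explicitly on GPU
  auto mask = torch::ones({2, 3}, torch::kBool);   // created internally, defaults to CPU
  // self.masked_fill_(mask, 1.0);                 // mixed devices: expected to error at runtime
  auto fixed = self.masked_fill(mask.to(self.device()), 1.0);  // analogous to the cast the new lowering pass inserts
}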

File tree

15 files changed: +382 −11 lines changed

core/conversion/conversionctx/BUILD

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 9 deletions
@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;

core/ir/ir.h

Lines changed: 8 additions & 0 deletions
@@ -11,6 +11,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace ir {

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
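For orientation, the relocated ir::Device can now be constructed from anywhere that includes core/ir/ir.h. A minimal sketch based only on the struct shown above; the chosen field values are illustrative, not from the commit:

#include "core/ir/ir.h"

// Minimal sketch: defaults come from the constructor above
// (kGPU, gpu_id 0, dla_core 0, no GPU fallback).
torch_tensorrt::core::ir::Device make_target_device() {
  torch_tensorrt::core::ir::Device device;
  device.gpu_id = 1;                 // illustrative: target the second GPU
  device.allow_gpu_fallback = true;  // illustrative: allow fallback from DLA to GPU
  return device;
}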

core/lowering/BUILD

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_ir>
         $<TARGET_OBJECTS:core_util>
 )

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )

core/lowering/lowering.cpp

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
+  passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }

core/lowering/lowering.h

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {

@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);
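A minimal sketch of how the new field and helper interact; the gpu_id value is illustrative. A LowerInfo whose target_device points at GPU 1 yields "cuda:1", which is the string LowerGraph (see lowering.cpp above) hands to the new casting passes.

#include <iostream>
#include "core/lowering/lowering.h"

int main() {
  torch_tensorrt::core::lowering::LowerInfo lower_info;
  lower_info.target_device.gpu_id = 1;  // user-specified target GPU (illustrative)

  // Prints "cuda:1"; LowerGraph passes this string to UnpackAndCastMaskedFill,
  // UnpackAndCastNumToTensor, and UnpackAndCastFull.
  std::cout << lower_info.getGPUDeviceString() << std::endl;
  return 0;
}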

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"

core/lowering/passes/device_casting.cpp

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string masked_fill_pattern = R"IR(
+    graph(%self, %mask, %value):
+      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
+      return (%out))IR";
+
+  // Calls to masked_fill_ often utilize CPU tensors, and as such
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%self, %mask, %value):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
+      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
+      return (%out))IR";
+
+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter masked_fill_rewriter;
+  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
+  masked_fill_rewriter.runOnGraph(graph);
+  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
+}
+
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string num_to_tensor_cast_pattern = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      return (%2))IR";
+
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
+      return (%3))IR";
+
+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
+  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
+  num_to_tensor_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
+}
+
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string full_cast_pattern = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
+      return (%out))IR";
+
+  // Tensors created via aten::full are initialized on cpu, and need to be cast to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
+      return (%out))IR";
+
+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter full_cast_rewriter;
+  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
+  full_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast full: " << *graph);
+}
+
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string scalar_implicit_cast_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::ScalarImplicit(%1)
+      return (%2))IR";
+
+  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
+  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
+  std::string scalar_implicit_clean_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::item(%1)
+      return (%2))IR";
+
+  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
+  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
+  scalar_implicit_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After replacing ScalarImplicit with item: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
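For readers unfamiliar with torch::jit::SubgraphRewriter, here is a minimal standalone sketch of how one of these rewrites behaves; the toy graph, the hard-coded "cuda:0" string, and the main() wrapper are illustrative assumptions, not part of the commit.

#include <iostream>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

int main() {
  // Toy graph containing the prim::NumToTensor idiom targeted by
  // UnpackAndCastNumToTensor above.
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x: Scalar):
      %t: Tensor = prim::NumToTensor(%x)
      return (%t))IR", graph.get());

  // Same pattern pair as the pass, with a device string ("cuda:0" here,
  // chosen for illustration) already spliced into the replacement.
  std::string pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR";
  std::string replacement = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda:0"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph);

  // The rewritten graph now routes the NumToTensor output through aten::to,
  // pinning it to the requested device.
  std::cout << *graph << std::endl;
  return 0;
}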
