Commit e94c64a

fix: Device casting issues with certain aten operators
- Investigated an issue arising with the BART-base model (https://huggingface.co/facebook/bart-base), where certain tensor inputs to TensorRT were on the CPU despite users explicitly casting all inputs to the GPU
- Traced the issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors that are passed between Torch and Torch-TensorRT engines
- Added lowering passes to ensure these operator edge cases are handled appropriately, and added a validation check in the runtime to avoid models crashing at runtime due to device mismatches
1 parent ce29cc7 commit e94c64a
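
For context, the failure mode described above can be reproduced with plain libtorch: a tensor created on the default (CPU) device, as aten::full does inside a graph, cannot be combined with a CUDA input. The snippet below is an illustrative sketch only (it assumes a CUDA-capable libtorch build and is not part of this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // User-provided input, explicitly placed on the GPU
  auto input = torch::ones({2, 4}, torch::kCUDA);

  // Internally-generated tensor; like aten::full, it defaults to the CPU device
  auto filler = torch::full({2, 4}, 1.0);

  try {
    auto out = input + filler;  // mixing CPU and CUDA tensors raises a device-mismatch error
  } catch (const c10::Error& e) {
    std::cout << "Device mismatch: " << e.what() << std::endl;
  }
  return 0;
}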

6 files changed: +150 −1 lines changed

core/lowering/lowering.cpp

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g);
+  passes::UnpackAndCastNumToTensor(g);
+  passes::UnpackAndCastFull(g);
+  passes::ReplaceScalarImplicit(g);
   LOG_GRAPH(*g);
 }


core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
core/lowering/passes/device_casting.cpp

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@ (new file)
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string masked_fill_pattern = R"IR(
    graph(%self, %mask, %value):
      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
      return (%out))IR";

  // Calls to masked_fill_ often utilize CPU tensors, and as such
  // should be casted to CUDA to avoid device mismatch errors
  std::string unpacked_pattern = R"IR(
    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value="cuda"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)

      # Value is cast to the type of the original tensor, since values default to float
      %is_float: bool = aten::is_floating_point(%self)
      %out: Tensor = prim::If(%is_float)
        block0():
          %no_cast: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
          -> (%no_cast)
        block1():
          %value_int: int = aten::Int(%value)
          %casted_int: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value_int)
          -> (%casted_int)

      return (%out))IR";

  torch::jit::SubgraphRewriter masked_fill_rewriter;
  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
  masked_fill_rewriter.runOnGraph(graph);
  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
}

void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string num_to_tensor_cast_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR";

  // 0D Tensors are initialized on cpu, and need to be casted to CUDA
  // to avoid device mismatch issues
  std::string num_to_tensor_clean_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";

  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
  num_to_tensor_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}

void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string full_cast_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
      return (%out))IR";

  // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA
  // to avoid device mismatch issues
  std::string full_clean_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %cuda: Device = prim::Constant[value="cuda"]()
      %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
      return (%out))IR";

  torch::jit::SubgraphRewriter full_cast_rewriter;
  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
  full_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast full: " << *graph);
}

void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string scalar_implicit_cast_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::ScalarImplicit(%1)
      return (%2))IR";

  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
  std::string scalar_implicit_clean_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::item(%1)
      return (%2))IR";

  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
  scalar_implicit_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After replace ScalarImplicit: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
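
As a rough illustration of how one of these passes could be exercised in isolation, the sketch below parses a small TorchScript graph containing prim::NumToTensor and runs UnpackAndCastNumToTensor over it. The include paths and the standalone main() are assumptions for illustration only and are not part of this commit:

#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"

int main() {
  // Graph matching the pattern targeted by UnpackAndCastNumToTensor
  const std::string source_ir = R"IR(
    graph(%1 : int):
      %2 : Tensor = prim::NumToTensor(%1)
      return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, g.get());

  // After the pass, the 0D tensor should be followed by an aten::to cast onto the "cuda" device
  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
  g->dump();
  return 0;
}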

core/lowering/passes/passes.h

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes
 } // namespace lowering

core/runtime/execute_engine.cpp

Lines changed: 26 additions & 1 deletion
@@ -63,16 +63,41 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   CudaDevice curr_device = get_current_device();
   LOG_DEBUG("Current Device: " << curr_device);

+  // Generic Target Device Prefix
+  std::string target_device = "cuda:";
+
   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
     CudaDevice device = select_cuda_device(compiled_engine->device_info);
     set_cuda_device(device);

-    std::string target_device = "cuda:" + std::to_string(device.id);
+    // Target device is new device
+    target_device += std::to_string(device.id);

     for (auto& in : inputs) {
       in = in.to(torch::Device(target_device));
     }
+  } else {
+    // Target device is current device
+    target_device += std::to_string(curr_device.id);
+
+    // For each input, ensure its current device is the desired target device
+    for (size_t i = 0; i < inputs.size(); i++) {
+      at::Tensor* in = &inputs[i];
+      std::string current_tensor_device = in->device().str();
+
+      // If current device string does not match target device, display warning and move tensor accordingly
+      if (current_tensor_device != target_device) {
+        LOG_WARNING(
+            "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+            << " but should be on " << target_device
+            << ". This tensor is being moved manually by the runtime but "
+            << "for performance considerations, ensure your inputs are all on GPU "
+            << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+            << "warning persists.");
+        *in = in->to(torch::Device(target_device));
+      }
+    }
   }

std::vector<void*> gpu_handles;
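
The check-and-move behavior added above can be sketched with plain libtorch outside the runtime. This is a simplified illustration only (it assumes a CUDA-capable build; the hard-coded target_device string stands in for the runtime's computed value, and none of this code is part of the commit):

#include <torch/torch.h>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string target_device = "cuda:0";  // assumed target for this sketch

  std::vector<at::Tensor> inputs = {
      torch::ones({2, 3}, torch::kCUDA),  // already on the target device
      torch::full({2, 3}, 1.0)            // defaults to CPU, as aten::full does inside a graph
  };

  for (size_t i = 0; i < inputs.size(); i++) {
    std::string current = inputs[i].device().str();
    if (current != target_device) {
      // Mirrors the runtime warning-and-move: inputs found off-device are relocated
      std::cout << "Input " << i << " is on " << current << " but should be on "
                << target_device << "; moving it." << std::endl;
      inputs[i] = inputs[i].to(torch::Device(target_device));
    }
  }
  return 0;
}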
