Skip to content

Commit 34366b3

Browse files
committed
fix: change the shouldNodeRunInTorch logic
Signed-off-by: Bo Wang <[email protected]>
1 parent 77f4c09 commit 34366b3

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

core/partitioning/partitioning.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include <queue>
33
#include "core/conversion/conversion.h"
44
#include "core/conversion/evaluators/evaluators.h"
5-
#include "core/partitioning/partitioningctx/PartitioningCtx.h"
65
#include "torch/csrc/jit/passes/constant_pooling.h"
76
#include "torch/csrc/jit/passes/dead_code_elimination.h"
87

@@ -102,8 +101,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
102101
if (!isTensor(output)) {
103102
for (auto use : output->uses()) {
104103
auto node = use.user;
105-
if (node->kind() != torch::jit::prim::Constant && node->kind() != torch::jit::prim::Return &&
106-
ctx->shouldNodeRunInTensorRT(node)) {
104+
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
107105
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
108106
q.push(node);
109107
}
@@ -454,7 +452,6 @@ void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
454452
LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
455453
RegisterSegmentsOutputs(ctx, block);
456454

457-
458455
// run shape analysis on each segmented block
459456
RunShapeAnalysis(ctx, block, example_tensor_map);
460457
}

core/partitioning/partitioning.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
#include "core/ir/ir.h"
99
#include "core/partitioning/partitioningctx/PartitioningCtx.h"
10-
#include "core/partitioning/partitioninginfo/PartitioningInfo.h"
11-
#include "core/partitioning/segmentedblock/SegmentedBlock.h"
1210
#include "core/util/prelude.h"
1311

1412
namespace torch_tensorrt {

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,30 @@ void PartitioningCtx::setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorD
4343

4444
bool PartitioningCtx::shouldNodeRunInTorch(torch::jit::Node* n) {
4545
auto iter = node_executor_decision_map.find(n);
46-
if (iter == node_executor_decision_map.end()) {
47-
LOG_ERROR("No info about node " << *n << " execution decision status.");
46+
auto decision = NodeExecutorDecision::kUNKNOWN;
47+
48+
if (iter != node_executor_decision_map.end()) {
49+
decision = iter->second;
50+
}
51+
if (decision == NodeExecutorDecision::kCONVERT || decision == NodeExecutorDecision::kUNKNOWN) {
52+
return false;
53+
} else {
54+
return true;
4855
}
49-
return iter->second != NodeExecutorDecision::kCONVERT;
5056
}
5157

5258
bool PartitioningCtx::shouldNodeRunInTensorRT(torch::jit::Node* n) {
53-
return !shouldNodeRunInTorch(n);
59+
auto iter = node_executor_decision_map.find(n);
60+
auto decision = NodeExecutorDecision::kUNKNOWN;
61+
if (iter != node_executor_decision_map.end()) {
62+
decision = iter->second;
63+
}
64+
65+
if (decision == NodeExecutorDecision::kCONVERT) {
66+
return true;
67+
} else {
68+
return false;
69+
}
5470
}
5571

5672
std::vector<torch::jit::Node*> PartitioningCtx::getNodesRunInTorch() {

0 commit comments

Comments
 (0)