Skip to content

Commit a165811

Browse files
committed
fix: fix typo
Signed-off-by: Bo Wang <[email protected]>
1 parent c3082f5 commit a165811

File tree

2 files changed

+6
-19
lines changed

2 files changed

+6
-19
lines changed

core/partitioning/partitioning.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
2-
3-
4-
1+
#include "core/partitioning/partitioning.h"
52
#include <queue>
6-
7-
#include "torch/csrc/jit/passes/constant_pooling.h"
8-
#include "torch/csrc/jit/passes/dead_code_elimination.h"
9-
103
#include "core/conversion/conversion.h"
114
#include "core/conversion/evaluators/evaluators.h"
12-
#include "core/partitioning/partitioning.h"
135
#include "core/partitioning/partitioningctx/PartitioningCtx.h"
6+
#include "torch/csrc/jit/passes/constant_pooling.h"
7+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
148

159
namespace torch_tensorrt {
1610
namespace core {
@@ -35,8 +29,6 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
3529
return false;
3630
}
3731

38-
39-
4032
// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
4133
void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
4234
// fallback nodes that produce entire graph's nonTensor output
@@ -91,7 +83,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
9183
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
9284
std::queue<torch::jit::Node*> q;
9385
for (auto& node : initial_fallback_nodes) {
94-
q.push(node.first);
86+
q.push(node);
9587
}
9688

9789
while (!q.empty()) {
@@ -110,8 +102,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
110102
if (!isTensor(output)) {
111103
for (auto use : output->uses()) {
112104
auto node = use.user;
113-
if (node->kind() != torch::jit::prim::Constant &&
114-
ctx->shouldNodeRunInTensorRT(node)) {
105+
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
115106
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
116107
q.push(node);
117108
}
@@ -147,7 +138,6 @@ std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx
147138
return min_block_fallback_nodes;
148139
}
149140

150-
151141
// Set the nodes that fallback because of min_block_size
152142
void SetMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
153143
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
@@ -328,7 +318,6 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
328318
return compile_to_trt;
329319
}
330320

331-
332321
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
333322
auto nodes = block->nodes();
334323

@@ -429,8 +418,6 @@ void SetNodeExecutorDecision(PartitioningCtx* ctx, torch::jit::Block* block) {
429418
SetMinBlockFallbackNodes(ctx, block);
430419
}
431420

432-
433-
434421
PartitionedGraph Partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
435422
LOG_DEBUG(ctx->settings);
436423

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ bool PartitioningCtx::isNodeExecutorKnown(torch::jit::Node* n) {
9292
}
9393
}
9494

95-
std::vector<torch::jit::Node*> PartitionCtx::getNodesRunInTorch() {
95+
std::vector<torch::jit::Node*> PartitioningCtx::getNodesRunInTorch() {
9696
std::vector<torch::jit::Node*> nodes_run_in_torch;
9797
for (auto i : node_executor_decision_map) {
9898
if (i.second == NodeExecutorDecision::kCONVERT) {

0 commit comments

Comments
 (0)