Skip to content

Commit 6c03afc

Browse files
authored
[TensorRT] Optimize performance by placing only TRT ops on the GPU device. (#860)
Signed-off-by: Tao Peng <[email protected]>
1 parent 8afcfff commit 6c03afc

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

tensorflow/compiler/tf2tensorrt/tool/tf2trt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def GetConversionParams(run_params):
6363
rewriter_config_template=None,
6464
max_workspace_size_bytes=1 << 25,
6565
precision_mode=run_params['precision_mode'],
66-
minimum_segment_size=2,
66+
minimum_segment_size=run_params['minimum_segment_size'],
6767
is_dynamic_op=run_params['dynamic_engine'],
6868
maximum_cached_engines=1,
6969
use_calibration=run_params['use_calibration'],
@@ -96,6 +96,7 @@ def ConvertGraph(run_params, saved_model_dir, trt_saved_model_dir):
9696
'use_calibration': False,
9797
'max_batch_size': 1024,
9898
'convert_online': False,
99+
'minimum_segment_size': 4,
99100
'use_ev': True,
100101
}
101102

tensorflow/core/common_runtime/placer.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/framework/types.h"
2929
#include "tensorflow/core/framework/types.pb.h"
3030
#include "tensorflow/core/lib/core/errors.h"
31+
#include "tensorflow/core/util/env_var.h"
3132
#include "tensorflow/core/util/dump_graph.h"
3233
#include "tensorflow/core/util/port.h"
3334

@@ -230,6 +231,29 @@ Status Placer::Run() {
230231
colocation_graph.AssignGpuStreamIdx(node);
231232
}
232233

234+
bool place_trtop_on_gpu_only = false;
235+
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
236+
"PLACE_TRT_OP_ON_GPU_ONLY", false, &place_trtop_on_gpu_only));
237+
// Place TRTEngineOp nodes on the GPU device and every other op on the CPU device.
238+
if (place_trtop_on_gpu_only) {
239+
std::string cpu_name, gpu_name;
240+
for (auto d : devices_->devices()) {
241+
if (d->name().find("device:CPU:") != std::string::npos) {
242+
cpu_name = d->name();
243+
} else if (d->name().find("device:GPU:") != std::string::npos) {
244+
gpu_name = d->name();
245+
}
246+
}
247+
248+
for (Node* n : graph_->op_nodes()) {
249+
if (n->type_string() == "TRTEngineOp") {
250+
n->set_assigned_device_name(gpu_name);
251+
} else {
252+
n->set_assigned_device_name(cpu_name);
253+
}
254+
}
255+
}
256+
233257
if (VLOG_IS_ON(3)) {
234258
DumpGraphToFile("placer_output", *graph_, nullptr);
235259
}

0 commit comments

Comments
 (0)