Skip to content

Commit 64006cc

Browse files
authored
[DeepRec] Add flag for device_placement_optimization. (#21)
Signed-off-by: Tao Peng <[email protected]>
1 parent af3064a commit 64006cc

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

tensorflow_serving/model_servers/main.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ int main(int argc, char** argv) {
227227
"oss_access_id"),
228228
tensorflow::Flag("oss_access_key", &options.oss_access_key,
229229
"oss_access_key"),
230+
tensorflow::Flag("enable_device_placement_optimization",
231+
&options.enable_device_placement_optimization,
232+
"Use device placement optimization."),
230233
tensorflow::Flag("use_multi_stream", &options.use_multi_stream,
231234
"Use multi-stream or not in session_group")};
232235

tensorflow_serving/model_servers/server.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,13 @@ Status CreatePlatformConfigMap(const Server::Options& server_options,
233233

234234
ParseTimelineConfig(server_options);
235235

236+
if (server_options.enable_device_placement_optimization) {
237+
session_bundle_config.mutable_session_config()
238+
->mutable_graph_options()
239+
->mutable_optimizer_options()
240+
->set_device_placement_optimization(true);
241+
}
242+
236243
session_bundle_config.mutable_session_config()
237244
->mutable_gpu_options()
238245
->set_per_process_gpu_memory_fraction(
@@ -294,6 +301,13 @@ Status CreatePlatformConfigMapV2(const Server::Options& server_options,
294301

295302
ParseTimelineConfig(server_options);
296303

304+
if (server_options.enable_device_placement_optimization) {
305+
model_session_config->mutable_session_config()
306+
->mutable_graph_options()
307+
->mutable_optimizer_options()
308+
->set_device_placement_optimization(true);
309+
}
310+
297311
// session num
298312
model_session_config->set_session_num(
299313
server_options.session_num_per_group);

tensorflow_serving/model_servers/server.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class Server {
100100
tensorflow::string oss_access_id = "";
101101
tensorflow::string oss_access_key = "";
102102

103+
bool enable_device_placement_optimization = false;
104+
103105
Options();
104106
};
105107

0 commit comments

Comments
 (0)