Skip to content

Commit 5017b0f

Browse files
authored
[DeepRec] Support session_group to use multiple GPUs. (#15)
1 parent 37f7ba4 commit 5017b0f

File tree

6 files changed

+29
-13
lines changed

6 files changed

+29
-13
lines changed

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ load("//tensorflow_serving:repo.bzl", "tensorflow_http_archive")
1111

1212
tensorflow_http_archive(
1313
name = "org_tensorflow",
14-
sha256 = "f11c490ff7cb90fba7e1fa31822de425e3243bddc767bc5ebf7ae0d373bfdbc1",
15-
git_commit = "dcdc32e2e565b7a815a5906a7e1e1df202322d37",
14+
sha256 = "836a1a5425a4f853abfe1dd03e2222bb9004540163fc2c56d57979b390cd5a31",
15+
git_commit = "c3a656197cae9647fad47ec7d84ccd722a502815",
1616
)
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

tensorflow_serving/model_servers/main.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ int main(int argc, char** argv) {
160160
&options.session_num_per_group,
161161
"Session num for a session group, "
162162
"default 0 means we not use session group."),
163+
tensorflow::Flag("gpu_ids_list",
164+
&options.gpu_ids_list,
165+
"GPU id list for a session group, "
166+
"default '', the legal format is '0,1,2...'."),
163167
tensorflow::Flag(
164168
"ssl_config_file", &options.ssl_config_file,
165169
"If non-empty, read an ascii SSLConfig protobuf from "

tensorflow_serving/model_servers/server.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,15 @@ Status CreatePlatformConfigMapV2(const Server::Options& server_options,
264264
model_session_config->set_session_num(
265265
server_options.session_num_per_group);
266266

267+
// gpu_ids_list for a session group
268+
if (!server_options.gpu_ids_list.empty()) {
269+
std::vector<string> ids =
270+
str_util::Split(server_options.gpu_ids_list, ',');
271+
for (auto id : ids) {
272+
model_session_config->add_gpu_ids(std::stoi(id));
273+
}
274+
}
275+
267276
// use_per_session_threads
268277
model_session_config->mutable_session_config()
269278
->set_use_per_session_threads(
@@ -366,7 +375,7 @@ Status Server::BuildAndStart(const Options& server_options) {
366375
if (server_options.model_config_file.empty()) {
367376
options.model_server_config = BuildSingleModelConfig(
368377
server_options.model_name, server_options.model_base_path);
369-
use_session_group = server_options.session_num_per_group > 1;
378+
use_session_group = server_options.session_num_per_group > 0;
370379
} else {
371380
TF_RETURN_IF_ERROR(ParseProtoTextFile<ModelServerConfig>(
372381
server_options.model_config_file, &options.model_server_config));

tensorflow_serving/model_servers/server.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class Server {
8888
bool use_per_session_threads = false;
8989
bool use_session_group = false;
9090
tensorflow::int32 session_num_per_group = 0;
91+
tensorflow::string gpu_ids_list = "";
9192
bool use_multi_stream = false;
9293

9394
Options();

tensorflow_serving/servables/tensorflow/bundle_factory_util.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,18 @@ SessionGroupOptions GetSessionOptions(const SessionGroupBundleConfig& config, in
8484
<< config.model_session_config_size();
8585
}
8686
options.config = config.model_session_config()[model_id].session_config();
87-
options.metadata.session_num = config.model_session_config()[model_id].session_num();
88-
if (options.metadata.session_num == 0) {
87+
options.metadata.session_count = config.model_session_config()[model_id].session_num();
88+
if (options.metadata.session_count == 0) {
8989
LOG(WARNING) << "User set use_session_group=true, but the #" << model_id
9090
<< " model config don't contain session_num field, "
9191
<< "please check platform_config_file config file. "
9292
<< "Now use default value 1.";
93-
options.metadata.session_num = 1;
93+
options.metadata.session_count = 1;
9494
}
95-
for (auto& conf : config.model_session_config()) {
96-
options.metadata.streams_vec.emplace_back(conf.session_num());
95+
if (!config.model_session_config()[model_id].gpu_ids().empty()) {
96+
for (auto id : config.model_session_config()[model_id].gpu_ids()) {
97+
options.metadata.gpu_ids.emplace_back(id);
98+
}
9799
}
98100
options.metadata.model_id = model_id;
99101
options.metadata.cpusets = config.model_session_config()[model_id].cpusets();

tensorflow_serving/servables/tensorflow/session_bundle_config.proto

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,13 @@ message SessionGroupBundleConfig {
170170
// See details at tensorflow/core/protobuf/config.proto.
171171
ConfigProto session_config = 1;
172172
// session_num of the session group.
173-
int32 session_num = 2;
174-
// gpu num
175-
//int32 gpu_num = 3;
173+
int32 session_num = 2;// [default = 1];
174+
// gpu list for current session_group
175+
repeated int32 gpu_ids = 3;
176176
// model id in multi-models
177-
int32 model_id = 3;
177+
int32 model_id = 4;
178178
// cpusets
179-
string cpusets = 4;
179+
string cpusets = 5;
180180
}
181181

182182
repeated ModelSessionConfig model_session_config = 784;

0 commit comments

Comments
 (0)