Skip to content

Commit 1e60069

Browse files
authored
[DeepRec] Support multi-model deployment in SessionGroup. (#13)
1 parent 91ac4b1 commit 1e60069

23 files changed

+195
-55
lines changed

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ load("//tensorflow_serving:repo.bzl", "tensorflow_http_archive")
1111

1212
tensorflow_http_archive(
1313
name = "org_tensorflow",
14-
sha256 = "46ad5154cec11995d5feba1401b6d0d72be457e48a9776bfce04b13c009ac412",
15-
git_commit = "0fe26688a57eee31bda56a57d8f05e7071c78c9b",
14+
sha256 = "95c2e401a57024a57fcf757498f0e962519e6967381b630d6afc3a4a80c7ee37",
15+
git_commit = "4b8c11fa0c72e1b3483ef8b1960c4790dda0e437",
1616
)
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

tensorflow_serving/config/model_server_config.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ message ModelConfig {
6565
//
6666
// (This can be changed once a model is in serving.)
6767
LoggingConfig logging_config = 6;
68+
69+
// model_id in multi-models.
70+
int32 model_id = 9;
6871
}
6972

7073
// Static list of models to be loaded for serving.

tensorflow_serving/core/servable_data.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,16 @@ class ServableData {
6565
// !this->ok().
6666
T ConsumeDataOrDie();
6767

68+
void SetModelId(int id) { model_id_ = id; }
69+
int GetModelId() const { return model_id_; }
70+
6871
private:
6972
ServableData() = delete;
7073

7174
const ServableId id_;
7275
const Status status_;
7376
T data_;
77+
int model_id_ = 0;
7478
};
7579

7680
// Helper static method to create a ServableData object. Caller may skip

tensorflow_serving/core/source_adapter.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ class UnarySourceAdapter : public SourceAdapter<InputType, OutputType> {
135135
// Converts a single InputType instance into a corresponding OutputType
136136
// instance.
137137
virtual Status Convert(const InputType& data, OutputType* converted_data) = 0;
138+
virtual Status Convert(const InputType& data, int model_id,
139+
OutputType* converted_data) {
140+
return Convert(data, converted_data);
141+
}
138142
};
139143

140144
// A source adapter that converts every incoming ServableData<InputType> item
@@ -210,7 +214,8 @@ UnarySourceAdapter<InputType, OutputType>::Adapt(
210214
for (const ServableData<InputType>& version : versions) {
211215
if (version.status().ok()) {
212216
OutputType adapted_data;
213-
Status adapt_status = Convert(version.DataOrDie(), &adapted_data);
217+
Status adapt_status = Convert(version.DataOrDie(),
218+
version.GetModelId(), &adapted_data);
214219
if (adapt_status.ok()) {
215220
adapted_versions.emplace_back(
216221
ServableData<OutputType>{version.id(), std::move(adapted_data)});

tensorflow_serving/core/test_util/session_test_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DelegatingSessionFactory : public SessionFactory {
6262

6363
Status NewSessionGroup(const SessionOptions& options,
6464
SessionGroup** out_session_group,
65-
int session_num = 1) {
65+
const SessionGroupMetadata& metadata) {
6666
return errors::Internal(
6767
"NewSessionGroup method not implemented in DelegatingSessionFactory.");
6868
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Usage:
2+
3+
CUDA_VISIBLE_DEVICES=0,1 tensorflow_model_server --use_session_group=true --model_config_file=session_group_multi_models_config --platform_config_file=session_group_multi_models_platform_config
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
model_config_list:{
2+
config:{
3+
name:"pb1",
4+
base_path:"/data/workspace/serving-model/multi_wdl_model/pb1",
5+
model_platform:"tensorflow"
6+
},
7+
config:{
8+
name:"pb2",
9+
base_path:"/data/workspace/serving-model/multi_wdl_model/pb2",
10+
model_platform:"tensorflow"
11+
},
12+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
platform_configs {
2+
key: "tensorflow"
3+
value {
4+
source_adapter_config {
5+
[type.googleapis.com/tensorflow.serving.SavedModelBundleSourceAdapterConfig] {
6+
legacy_config {
7+
session_config {
8+
gpu_options {
9+
allow_growth: true
10+
}
11+
}
12+
}
13+
}
14+
}
15+
}
16+
}
17+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
model_config_list:{
2+
config:{
3+
name:"pb1",
4+
base_path:"/data/workspace/serving-model/multi_wdl_model/pb1",
5+
model_platform:"tensorflow",
6+
model_id: 0
7+
},
8+
config:{
9+
name:"pb2",
10+
base_path:"/data/workspace/serving-model/multi_wdl_model/pb2",
11+
model_platform:"tensorflow",
12+
model_id: 1
13+
},
14+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
platform_configs {
2+
key: "tensorflow"
3+
value {
4+
source_adapter_config {
5+
[type.googleapis.com/tensorflow.serving.SavedModelBundleV2SourceAdapterConfig] {
6+
legacy_config {
7+
model_session_config {
8+
session_config {
9+
gpu_options {
10+
allow_growth: true
11+
}
12+
intra_op_parallelism_threads: 8
13+
inter_op_parallelism_threads: 8
14+
use_per_session_threads: true
15+
use_per_session_stream: true
16+
}
17+
session_num: 2
18+
}
19+
model_session_config {
20+
session_config {
21+
gpu_options {
22+
allow_growth: true
23+
}
24+
intra_op_parallelism_threads: 16
25+
inter_op_parallelism_threads: 16
26+
use_per_session_threads: true
27+
use_per_session_stream: true
28+
}
29+
session_num: 4
30+
}
31+
}
32+
}
33+
}
34+
}
35+
}
36+

0 commit comments

Comments
 (0)