Skip to content

Commit 91ac4b1

Browse files
authored
[DeepRec] Support GetModelMetadata request in session_group. (#12)
1 parent 359e038 commit 91ac4b1

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

tensorflow_serving/model_servers/http_rest_api_handler.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,11 @@ Status HttpRestApiHandler::ProcessModelMetadataRequest(
270270
request.mutable_model_spec()->mutable_version()->set_value(version);
271271
}
272272

273+
ModelMetaOption opt;
274+
opt.use_session_group = use_session_group_;
273275
GetModelMetadataResponse response;
274276
TF_RETURN_IF_ERROR(
275-
GetModelMetadataImpl::GetModelMetadata(core_, request, &response));
277+
GetModelMetadataImpl::GetModelMetadata(core_, request, &response, opt));
276278
JsonPrintOptions opts;
277279
opts.add_whitespace = true;
278280
opts.always_print_primitive_fields = true;

tensorflow_serving/model_servers/prediction_service_impl.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ ::grpc::Status PredictionServiceImpl::GetModelMetadata(
6161
errors::InvalidArgument("GetModelMetadata API is only available when "
6262
"use_saved_model is set to true"));
6363
}
64+
ModelMetaOption opt;
65+
opt.use_session_group = use_session_group_;
6466
const ::grpc::Status status = ToGRPCStatus(
65-
GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
67+
GetModelMetadataImpl::GetModelMetadata(core_, *request, response, opt));
6668
if (!status.ok()) {
6769
VLOG(1) << "GetModelMetadata failed: " << status.error_message();
6870
}

tensorflow_serving/servables/tensorflow/get_model_metadata_impl.cc

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,34 @@ Status ValidateGetModelMetadataRequest(const GetModelMetadataRequest& request) {
4949

5050
Status SavedModelGetSignatureDef(ServerCore* core, const ModelSpec& model_spec,
5151
const GetModelMetadataRequest& request,
52-
GetModelMetadataResponse* response) {
53-
ServableHandle<SavedModelBundle> bundle;
54-
TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &bundle));
52+
GetModelMetadataResponse* response,
53+
ModelMetaOption opt) {
5554
SignatureDefMap signature_def_map;
56-
for (const auto& signature : bundle->meta_graph_def.signature_def()) {
57-
(*signature_def_map.mutable_signature_def())[signature.first] =
58-
signature.second;
55+
std::string model_name;
56+
int64 version;
57+
if (opt.use_session_group) {
58+
ServableHandle<SavedModelBundleV2> bundle;
59+
TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &bundle));
60+
for (const auto& signature : bundle->meta_graph_def.signature_def()) {
61+
(*signature_def_map.mutable_signature_def())[signature.first] =
62+
signature.second;
63+
}
64+
model_name = bundle.id().name;
65+
version = bundle.id().version;
66+
} else {
67+
ServableHandle<SavedModelBundle> bundle;
68+
TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &bundle));
69+
for (const auto& signature : bundle->meta_graph_def.signature_def()) {
70+
(*signature_def_map.mutable_signature_def())[signature.first] =
71+
signature.second;
72+
}
73+
model_name = bundle.id().name;
74+
version = bundle.id().version;
5975
}
76+
6077
auto response_model_spec = response->mutable_model_spec();
61-
response_model_spec->set_name(bundle.id().name);
62-
response_model_spec->mutable_version()->set_value(bundle.id().version);
78+
response_model_spec->set_name(model_name);
79+
response_model_spec->mutable_version()->set_value(version);
6380

6481
(*response->mutable_metadata())[GetModelMetadataImpl::kSignatureDef].PackFrom(
6582
signature_def_map);
@@ -72,24 +89,26 @@ constexpr const char GetModelMetadataImpl::kSignatureDef[];
7289

7390
Status GetModelMetadataImpl::GetModelMetadata(
7491
ServerCore* core, const GetModelMetadataRequest& request,
75-
GetModelMetadataResponse* response) {
92+
GetModelMetadataResponse* response,
93+
ModelMetaOption opt) {
7694
if (!request.has_model_spec()) {
7795
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
7896
"Missing ModelSpec");
7997
}
8098
return GetModelMetadataWithModelSpec(core, request.model_spec(), request,
81-
response);
99+
response, opt);
82100
}
83101

84102
Status GetModelMetadataImpl::GetModelMetadataWithModelSpec(
85103
ServerCore* core, const ModelSpec& model_spec,
86104
const GetModelMetadataRequest& request,
87-
GetModelMetadataResponse* response) {
105+
GetModelMetadataResponse* response,
106+
ModelMetaOption opt) {
88107
TF_RETURN_IF_ERROR(ValidateGetModelMetadataRequest(request));
89108
for (const auto& metadata_field : request.metadata_field()) {
90109
if (metadata_field == kSignatureDef) {
91110
TF_RETURN_IF_ERROR(
92-
SavedModelGetSignatureDef(core, model_spec, request, response));
111+
SavedModelGetSignatureDef(core, model_spec, request, response, opt));
93112
} else {
94113
return tensorflow::errors::InvalidArgument(
95114
"MetadataField %s is not supported", metadata_field);

tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,26 @@ limitations under the License.
2323
namespace tensorflow {
2424
namespace serving {
2525

26+
struct ModelMetaOption {
27+
bool use_session_group = false;
28+
};
29+
2630
class GetModelMetadataImpl {
2731
public:
2832
static constexpr const char kSignatureDef[] = "signature_def";
2933

3034
static Status GetModelMetadata(ServerCore* core,
3135
const GetModelMetadataRequest& request,
32-
GetModelMetadataResponse* response);
36+
GetModelMetadataResponse* response,
37+
ModelMetaOption opt = ModelMetaOption());
3338

3439
// Like GetModelMetadata(), but uses 'model_spec' instead of the one embedded
3540
// in 'request'.
3641
static Status GetModelMetadataWithModelSpec(
3742
ServerCore* core, const ModelSpec& model_spec,
3843
const GetModelMetadataRequest& request,
39-
GetModelMetadataResponse* response);
44+
GetModelMetadataResponse* response,
45+
ModelMetaOption opt = ModelMetaOption());
4046
};
4147

4248
} // namespace serving

0 commit comments

Comments
 (0)