Skip to content

Commit 44026a6

Browse files
authored
Refactor Configs (#218)
* refactoring configs
1 parent (0182fa5); this commit: 44026a6

23 files changed

+1038
-1199
lines changed

mii/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from .client import MIIClient, mii_query_handle
88
from .deployment import deploy
99
from .terminate import terminate
10-
from .constants import DeploymentType, Tasks
10+
from .constants import DeploymentType, TaskType
1111
from .aml_related.utils import aml_output_path
12+
from .config import MIIConfig, ModelConfig
1213
from .utils import get_supported_models
13-
from .config import MIIConfig, LoadBalancerConfig
1414
from .grpc_related.proto import modelresponse_pb2_grpc
1515

1616
__version__ = "0.0.0"

mii/client.py

Lines changed: 47 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,16 @@
66
import grpc
77
import requests
88
import mii
9-
from mii.utils import get_task
10-
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
11-
from mii.constants import GRPC_MAX_MSG_SIZE, Tasks, DeploymentType
12-
from mii.method_table import GRPC_METHOD_TABLE
9+
from .grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
10+
from .constants import GRPC_MAX_MSG_SIZE, TaskType, DeploymentType
11+
from .method_table import GRPC_METHOD_TABLE
12+
from .config import MIIConfig
13+
from .utils import import_score_file
1314

1415

15-
def _get_deployment_info(deployment_name):
16-
configs = mii.utils.import_score_file(deployment_name, DeploymentType.LOCAL).configs
17-
task = configs[mii.constants.TASK_NAME_KEY]
18-
mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
19-
mii_configs = mii.config.MIIConfig(**mii_configs_dict)
20-
21-
assert task is not None, "The task name should be set before calling init"
22-
return task, mii_configs
16+
def _get_mii_config(deployment_name):
17+
mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
18+
return MIIConfig(**mii_config)
2319

2420

2521
def mii_query_handle(deployment_name):
@@ -39,27 +35,33 @@ def mii_query_handle(deployment_name):
3935
inference_pipeline, task = mii.non_persistent_models[deployment_name]
4036
return MIINonPersistentClient(task, deployment_name)
4137

42-
task_name, mii_configs = _get_deployment_info(deployment_name)
43-
return MIIClient(task_name, "localhost", mii_configs.port_number)
38+
mii_config = _get_mii_config(deployment_name)
39+
return MIIClient(mii_config.model_config.task,
40+
"localhost", # TODO: This can probably be removed
41+
mii_config.port_number)
4442

4543

4644
def create_channel(host, port):
47-
return grpc.aio.insecure_channel(f'{host}:{port}',
48-
options=[('grpc.max_send_message_length',
49-
GRPC_MAX_MSG_SIZE),
50-
('grpc.max_receive_message_length',
51-
GRPC_MAX_MSG_SIZE)])
52-
53-
54-
class MIIClient():
45+
return grpc.aio.insecure_channel(
46+
f"{host}:{port}",
47+
options=[
48+
("grpc.max_send_message_length",
49+
GRPC_MAX_MSG_SIZE),
50+
("grpc.max_receive_message_length",
51+
GRPC_MAX_MSG_SIZE),
52+
],
53+
)
54+
55+
56+
class MIIClient:
5557
"""
5658
Client to send queries to a single endpoint.
5759
"""
58-
def __init__(self, task_name, host, port):
60+
def __init__(self, task, host, port):
5961
self.asyncio_loop = asyncio.get_event_loop()
6062
channel = create_channel(host, port)
6163
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
62-
self.task = get_task(task_name)
64+
self.task = task
6365

6466
async def _request_async_response(self, request_dict, **query_kwargs):
6567
if self.task not in GRPC_METHOD_TABLE:
@@ -87,7 +89,9 @@ async def create_session_async(self, session_id):
8789
modelresponse_pb2.SessionID(session_id=session_id))
8890

8991
def create_session(self, session_id):
90-
assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
92+
assert (
93+
self.task == TaskType.TEXT_GENERATION
94+
), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
9195
return self.asyncio_loop.run_until_complete(
9296
self.create_session_async(session_id))
9397

@@ -96,89 +100,40 @@ async def destroy_session_async(self, session_id):
96100
)
97101

98102
def destroy_session(self, session_id):
99-
assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
103+
assert (
104+
self.task == TaskType.TEXT_GENERATION
105+
), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
100106
self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))
101107

102108

103-
class MIITensorParallelClient():
104-
"""
105-
Client to send queries to multiple endpoints in parallel.
106-
This is used to call multiple servers deployed for tensor parallelism.
107-
"""
108-
def __init__(self, task_name, host, ports):
109-
self.task = get_task(task_name)
110-
self.clients = [MIIClient(task_name, host, port) for port in ports]
111-
self.asyncio_loop = asyncio.get_event_loop()
112-
113-
# runs task in parallel and return the result from the first task
114-
async def _query_in_tensor_parallel(self, request_string, query_kwargs):
115-
responses = []
116-
for client in self.clients:
117-
responses.append(
118-
self.asyncio_loop.create_task(
119-
client._request_async_response(request_string,
120-
**query_kwargs)))
121-
122-
await responses[0]
123-
return responses[0]
124-
125-
def query(self, request_dict, **query_kwargs):
126-
"""Query a local deployment:
127-
128-
mii/examples/local/gpt2-query-example.py
129-
mii/examples/local/roberta-qa-query-example.py
130-
131-
Arguments:
132-
request_dict: A task specific request dictionary consisting of the inputs to the models
133-
query_kwargs: additional query parameters for the model
134-
135-
Returns:
136-
response: Response of the model
137-
"""
138-
response = self.asyncio_loop.run_until_complete(
139-
self._query_in_tensor_parallel(request_dict,
140-
query_kwargs))
141-
ret = response.result()
142-
return ret
143-
144-
def terminate(self):
145-
"""Terminates the deployment"""
146-
for client in self.clients:
147-
client.terminate()
148-
149-
def create_session(self, session_id):
150-
for client in self.clients:
151-
client.create_session(session_id)
152-
153-
def destroy_session(self, session_id):
154-
for client in self.clients:
155-
client.destroy_session(session_id)
156-
157-
158-
class MIINonPersistentClient():
109+
class MIINonPersistentClient:
159110
def __init__(self, task, deployment_name):
160111
self.task = task
161112
self.deployment_name = deployment_name
162113

163114
def query(self, request_dict, **query_kwargs):
164-
assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
115+
assert (
116+
self.deployment_name in mii.non_persistent_models
117+
), f"deployment: {self.deployment_name} not found"
165118
task_methods = GRPC_METHOD_TABLE[self.task]
166119
inference_pipeline = mii.non_persistent_models[self.deployment_name][0]
167120

168-
if self.task == Tasks.QUESTION_ANSWERING:
169-
if 'question' not in request_dict or 'context' not in request_dict:
121+
# TODO: refactor so this code is shared between non-persistent and
122+
# persistent deployments in method_table.py
123+
if self.task == TaskType.QUESTION_ANSWERING:
124+
if "question" not in request_dict or "context" not in request_dict:
170125
raise Exception(
171126
"Question Answering Task requires 'question' and 'context' keys")
172127
args = (request_dict["question"], request_dict["context"])
173128
kwargs = query_kwargs
174129

175-
elif self.task == Tasks.CONVERSATIONAL:
176-
conv = task_methods.create_conversation(request_dict, **query_kwargs)
130+
elif self.task == TaskType.CONVERSATIONAL:
131+
conv = task_methods.create_conversation(request_dict)
177132
args = (conv, )
178-
kwargs = {}
133+
kwargs = query_kwargs
179134

180135
else:
181-
args = (request_dict['query'], )
136+
args = (request_dict["query"], )
182137
kwargs = query_kwargs
183138

184139
return task_methods.run_inference(inference_pipeline, args, query_kwargs)
@@ -189,6 +144,6 @@ def terminate(self):
189144

190145

191146
def terminate_restful_gateway(deployment_name):
192-
_, mii_configs = _get_deployment_info(deployment_name)
193-
if mii_configs.enable_restful_api:
194-
requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
147+
mii_config = _get_mii_config(deployment_name)
148+
if mii_config.enable_restful_api:
149+
requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")

0 commit comments

Comments (0)