@@ -6,20 +6,16 @@
 import grpc
 import requests
 import mii
-from mii.utils import get_task
-from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
-from mii.constants import GRPC_MAX_MSG_SIZE, Tasks, DeploymentType
-from mii.method_table import GRPC_METHOD_TABLE
+from .grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
+from .constants import GRPC_MAX_MSG_SIZE, TaskType, DeploymentType
+from .method_table import GRPC_METHOD_TABLE
+from .config import MIIConfig
+from .utils import import_score_file
 
 
-def _get_deployment_info(deployment_name):
-    configs = mii.utils.import_score_file(deployment_name, DeploymentType.LOCAL).configs
-    task = configs[mii.constants.TASK_NAME_KEY]
-    mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
-    mii_configs = mii.config.MIIConfig(**mii_configs_dict)
-
-    assert task is not None, "The task name should be set before calling init"
-    return task, mii_configs
+def _get_mii_config(deployment_name):
+    mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
+    return MIIConfig(**mii_config)
 
 
 def mii_query_handle(deployment_name):
@@ -39,27 +35,33 @@ def mii_query_handle(deployment_name):
         inference_pipeline, task = mii.non_persistent_models[deployment_name]
         return MIINonPersistentClient(task, deployment_name)
 
-    task_name, mii_configs = _get_deployment_info(deployment_name)
-    return MIIClient(task_name, "localhost", mii_configs.port_number)
+    mii_config = _get_mii_config(deployment_name)
+    return MIIClient(mii_config.model_config.task,
+                     "localhost",  # TODO: This can probably be removed
+                     mii_config.port_number)
 
 
 def create_channel(host, port):
-    return grpc.aio.insecure_channel(f'{host}:{port}',
-                                     options=[('grpc.max_send_message_length',
-                                               GRPC_MAX_MSG_SIZE),
-                                              ('grpc.max_receive_message_length',
-                                               GRPC_MAX_MSG_SIZE)])
-
-
-class MIIClient():
+    return grpc.aio.insecure_channel(
+        f"{host}:{port}",
+        options=[
+            ("grpc.max_send_message_length",
+             GRPC_MAX_MSG_SIZE),
+            ("grpc.max_receive_message_length",
+             GRPC_MAX_MSG_SIZE),
+        ],
+    )
+
+
+class MIIClient:
     """
     Client to send queries to a single endpoint.
     """
-    def __init__(self, task_name, host, port):
+    def __init__(self, task, host, port):
         self.asyncio_loop = asyncio.get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
-        self.task = get_task(task_name)
+        self.task = task
 
     async def _request_async_response(self, request_dict, **query_kwargs):
         if self.task not in GRPC_METHOD_TABLE:
@@ -87,7 +89,9 @@ async def create_session_async(self, session_id):
             modelresponse_pb2.SessionID(session_id=session_id))
 
     def create_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
         return self.asyncio_loop.run_until_complete(
             self.create_session_async(session_id))
 
@@ -96,89 +100,40 @@ async def destroy_session_async(self, session_id):
         )
 
     def destroy_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
         self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))
 
 
-class MIITensorParallelClient():
-    """
-    Client to send queries to multiple endpoints in parallel.
-    This is used to call multiple servers deployed for tensor parallelism.
-    """
-    def __init__(self, task_name, host, ports):
-        self.task = get_task(task_name)
-        self.clients = [MIIClient(task_name, host, port) for port in ports]
-        self.asyncio_loop = asyncio.get_event_loop()
-
-    # runs task in parallel and return the result from the first task
-    async def _query_in_tensor_parallel(self, request_string, query_kwargs):
-        responses = []
-        for client in self.clients:
-            responses.append(
-                self.asyncio_loop.create_task(
-                    client._request_async_response(request_string,
-                                                   **query_kwargs)))
-
-        await responses[0]
-        return responses[0]
-
-    def query(self, request_dict, **query_kwargs):
-        """Query a local deployment:
-
-            mii/examples/local/gpt2-query-example.py
-            mii/examples/local/roberta-qa-query-example.py
-
-        Arguments:
-            request_dict: A task specific request dictionary consisting of the inputs to the models
-            query_kwargs: additional query parameters for the model
-
-        Returns:
-            response: Response of the model
-        """
-        response = self.asyncio_loop.run_until_complete(
-            self._query_in_tensor_parallel(request_dict,
-                                           query_kwargs))
-        ret = response.result()
-        return ret
-
-    def terminate(self):
-        """Terminates the deployment"""
-        for client in self.clients:
-            client.terminate()
-
-    def create_session(self, session_id):
-        for client in self.clients:
-            client.create_session(session_id)
-
-    def destroy_session(self, session_id):
-        for client in self.clients:
-            client.destroy_session(session_id)
-
-
-class MIINonPersistentClient():
+class MIINonPersistentClient:
     def __init__(self, task, deployment_name):
         self.task = task
         self.deployment_name = deployment_name
 
     def query(self, request_dict, **query_kwargs):
-        assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
+        assert (
+            self.deployment_name in mii.non_persistent_models
+        ), f"deployment: {self.deployment_name} not found"
         task_methods = GRPC_METHOD_TABLE[self.task]
         inference_pipeline = mii.non_persistent_models[self.deployment_name][0]
 
-        if self.task == Tasks.QUESTION_ANSWERING:
-            if 'question' not in request_dict or 'context' not in request_dict:
+        # TODO: refactor so this code is shared between non-persistent and
+        # persistent deployments in method_table.py
+        if self.task == TaskType.QUESTION_ANSWERING:
+            if "question" not in request_dict or "context" not in request_dict:
                 raise Exception(
                     "Question Answering Task requires 'question' and 'context' keys")
             args = (request_dict["question"], request_dict["context"])
             kwargs = query_kwargs
 
-        elif self.task == Tasks.CONVERSATIONAL:
-            conv = task_methods.create_conversation(request_dict, **query_kwargs)
+        elif self.task == TaskType.CONVERSATIONAL:
+            conv = task_methods.create_conversation(request_dict)
             args = (conv, )
-            kwargs = {}
+            kwargs = query_kwargs
 
         else:
-            args = (request_dict['query'], )
+            args = (request_dict["query"], )
             kwargs = query_kwargs
 
         return task_methods.run_inference(inference_pipeline, args, query_kwargs)
@@ -189,6 +144,6 @@ def terminate(self):
 
 
 def terminate_restful_gateway(deployment_name):
-    _, mii_configs = _get_deployment_info(deployment_name)
-    if mii_configs.enable_restful_api:
-        requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
+    mii_config = _get_mii_config(deployment_name)
+    if mii_config.enable_restful_api:
+        requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")
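
For reviewers, a minimal usage sketch of the client API after this refactor. It is a sketch under assumptions, not part of the diff: the deployment name "text-gen" and the max_new_tokens kwarg are made up, MIIClient.query (unchanged by this diff) is assumed to keep the request_dict/**query_kwargs signature that MIINonPersistentClient.query shows above, and terminate_restful_gateway is assumed to be re-exported at the package level.

import mii

# Returns MIINonPersistentClient for in-process models, otherwise an
# MIIClient pointed at the local gRPC server, with task and port now read
# from the scored MIIConfig (see _get_mii_config above).
client = mii.mii_query_handle("text-gen")  # hypothetical deployment name

# Sessions are only valid for text generation, enforced by the
# TaskType.TEXT_GENERATION asserts in this diff; how a session is bound
# to an individual query is not shown in this excerpt.
client.create_session("session-0")
result = client.query({"query": "DeepSpeed is"}, max_new_tokens=64)
client.destroy_session("session-0")
print(result)

# Shut down the RESTful gateway, a no-op unless the deployment's
# MIIConfig set enable_restful_api.
mii.terminate_restful_gateway("text-gen")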