4
4
from .encryptor import DummyEncryptor
5
5
from .compressor import DummyCompressor
6
6
from ..pymilo_obj import Export , Import
7
+ from .param import PYMILO_CLIENT_INVALID_MODE , PYMILO_CLIENT_MODEL_SYNCHED , \
8
+ PYMILO_CLIENT_LOCAL_MODEL_UPLOADED , PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED , \
9
+ PYMILO_CLIENT_INVALID_ATTRIBUTE , PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL
7
10
from .communicator import RESTClientCommunicator
8
11
from ..transporters .general_data_structure_transporter import GeneralDataStructureTransporter
9
12
@@ -22,20 +25,17 @@ def __init__(
22
25
self ,
23
26
model = None ,
24
27
mode = Mode .LOCAL ,
25
- server = "http://127.0.0.1" ,
26
- port = 8000
27
- ):
28
+ server_url = "http://127.0.0.1:8000" ,
29
+ ):
28
30
"""
29
31
Initialize the Pymilo PymiloClient instance.
30
32
31
33
:param model: the ML model PyMiloClient wrapped around
32
34
:type model: Any
33
35
:param mode: the mode in which PymiloClient should work, either LOCAL mode or DELEGATE
34
36
:type mode: str (LOCAL|DELEGATE)
35
- :param server: the url to which PyMilo Server listens
36
- :type server: str
37
- :param port: the port to which PyMilo Server listens
38
- :type port: int
37
+ :param server_url: the url to which PyMilo Server listens
38
+ :type server_url: str
39
39
:return: an instance of the Pymilo PymiloClient class
40
40
"""
41
41
self ._client_id = "0x_client_id"
@@ -44,10 +44,7 @@ def __init__(
44
44
self ._mode = mode
45
45
self ._compressor = DummyCompressor ()
46
46
self ._encryptor = DummyEncryptor ()
47
- self ._communicator = RESTClientCommunicator (
48
- server_url = "{}:{}" .format (server , port )
49
- )
50
-
47
+ self ._communicator = RESTClientCommunicator (server_url )
51
48
52
49
def toggle_mode (self , mode = Mode .LOCAL ):
53
50
"""
@@ -56,41 +53,41 @@ def toggle_mode(self, mode=Mode.LOCAL):
56
53
:return: None
57
54
"""
58
55
if mode not in Mode .__members__ .values ():
59
- raise Exception ("Invalid mode, the given mode should be either `LOCAL`[default] or `DELEGATE`." )
60
- self ._mode = mode
56
+ raise Exception (PYMILO_CLIENT_INVALID_MODE )
57
+ if mode != self ._mode :
58
+ self ._mode = mode
61
59
62
60
def download (self ):
63
61
"""
64
62
Request for the remote ML model to download.
65
63
66
64
:return: None
67
65
"""
68
- response = self ._communicator .download ({
69
- "client_id" : self ._client_id ,
66
+ serialized_model = self ._communicator .download ({
67
+ "client_id" : self ._client_id ,
70
68
"model_id" : self ._model_id
71
69
})
72
- if response .status_code != 200 :
73
- print ("Remote model download failed." )
74
- print ("Remote model downloaded successfully." )
75
- serialized_model = response .json ()["payload" ]
70
+ if serialized_model is None :
71
+ print (PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL )
72
+ return
76
73
self ._model = Import (file_adr = None , json_dump = serialized_model ).to_model ()
77
- print ("Local model updated successfully." )
74
+ print (PYMILO_CLIENT_MODEL_SYNCHED )
78
75
79
76
def upload (self ):
80
77
"""
81
78
Upload the local ML model to the remote server.
82
79
83
80
:return: None
84
81
"""
85
- response = self ._communicator .upload ({
86
- "client_id" : self ._client_id ,
82
+ succeed = self ._communicator .upload ({
83
+ "client_id" : self ._client_id ,
87
84
"model_id" : self ._model_id ,
88
85
"model" : Export (self ._model ).to_json (),
89
86
})
90
- if response . status_code == 200 :
91
- print ("Local model uploaded successfully." )
87
+ if succeed :
88
+ print (PYMILO_CLIENT_LOCAL_MODEL_UPLOADED )
92
89
else :
93
- print ("Local model upload failed." )
90
+ print (PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED )
94
91
95
92
def __getattr__ (self , attribute ):
96
93
"""
@@ -105,18 +102,31 @@ def __getattr__(self, attribute):
105
102
if attribute in dir (self ._model ):
106
103
return getattr (self ._model , attribute )
107
104
else :
108
- raise AttributeError ("This attribute doesn't exist in either PymiloClient or the inner ML model." )
105
+ raise AttributeError (PYMILO_CLIENT_INVALID_ATTRIBUTE )
109
106
elif self ._mode == Mode .DELEGATE :
110
107
gdst = GeneralDataStructureTransporter ()
108
+ response = self ._communicator .attribute_type (
109
+ self ._encryptor .encrypt (
110
+ self ._compressor .compress (
111
+ {
112
+ "client_id" : self ._client_id ,
113
+ "model_id" : self ._model_id ,
114
+ "attribute" : attribute ,
115
+ }
116
+ )
117
+ )
118
+ )
119
+ if response ["attribute type" ] == "field" :
120
+ return gdst .deserialize (response , "attribute value" , None )
121
+
111
122
def relayer (* args , ** kwargs ):
112
- print (f"Method '{ attribute } ' called with args: { args } and kwargs: { kwargs } " )
113
123
payload = {
114
124
"client_id" : self ._client_id ,
115
125
"model_id" : self ._model_id ,
116
126
'attribute' : attribute ,
117
127
'args' : args ,
118
128
'kwargs' : kwargs ,
119
- }
129
+ }
120
130
payload ["args" ] = gdst .serialize (payload , "args" , None )
121
131
payload ["kwargs" ] = gdst .serialize (payload , "kwargs" , None )
122
132
result = self ._communicator .attribute_call (
@@ -125,7 +135,6 @@ def relayer(*args, **kwargs):
125
135
payload
126
136
)
127
137
)
128
- ). json ()
138
+ )
129
139
return gdst .deserialize (result , "payload" , None )
130
- relayer .__doc__ = getattr (self ._model .__class__ , attribute ).__doc__
131
140
return relayer
0 commit comments