Skip to content

Commit 1a75629

Browse files
authored
Feature/remote url (#117)
* add requests to `requirements.txt` * add a valid `.json` file to be used in * add download function * integrate download from url * add download from url test cases * fix exception raise during test * add `Retry` mechanism * remove unused imports + removing trailing whitespaces * update docstring for url * refactor error messages and add them to `pymilo_param.py`
1 parent a3cffac commit 1a75629

File tree

9 files changed

+146
-25
lines changed

9 files changed

+146
-25
lines changed

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
numpy==2.0.1
22
scikit-learn==1.5.1
33
scipy>=0.19.1
4+
requests>=2.0.0
45
setuptools>=40.8.0
56
vulture>=1.0
67
bandit>=1.5.1

pymilo/exceptions/deserialize_exception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, meta_data):
4242
error_type = meta_data['error_type']
4343
error_type_to_message = {
4444
DeserializationErrorTypes.CORRUPTED_JSON_FILE:
45-
'the given json file is not a valid .json file.',
45+
'the given file is not a valid .json file.',
4646
DeserializationErrorTypes.INVALID_MODEL:
4747
'the given model is not supported or is not a valid model.',
4848
DeserializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE:

pymilo/pymilo_obj.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo modules."""
3-
from .pymilo_func import get_sklearn_data, get_sklearn_version, to_sklearn_model
4-
from .utils.util import get_sklearn_type
5-
from .pymilo_param import PYMILO_VERSION, PYMILO_VERSION_DOES_NOT_EXIST, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS
63
import json
7-
8-
from .exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
9-
from .exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
10-
from traceback import format_exc
11-
from warnings import warn
124
from copy import deepcopy
5+
from warnings import warn
6+
from traceback import format_exc
7+
from .utils.util import get_sklearn_type, download_model
8+
from .pymilo_func import get_sklearn_data, get_sklearn_version, to_sklearn_model
9+
from .exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
10+
from .exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
11+
from .pymilo_param import PYMILO_VERSION, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS, INVALID_IMPORT_INIT_PARAMS
1312

1413

1514
class Export:
@@ -84,25 +83,29 @@ class Import:
8483
>>> imported_sklearn_model.predict(x_test)
8584
"""
8685

87-
def __init__(self, file_adr, json_dump=None):
86+
def __init__(self, file_adr=None, json_dump=None, url=None):
8887
"""
8988
Initialize the Pymilo Import instance.
9089
9190
:param file_adr: the file path where the serialized model's JSON file is located.
92-
:type file_adr: string
91+
:type file_adr: str or None
9392
:param json_dump: the json dump of the associated model, it can be None(reading from the file_adr)
9493
:type json_dump: str or None
94+
:param url: url to exported JSON file
95+
:type url: str or None
9596
:return: an instance of the Pymilo Import class
9697
"""
9798
serialized_model_obj = None
99+
if url is not None:
100+
serialized_model_obj = download_model(url)
101+
elif json_dump is not None and isinstance(json_dump, str):
102+
serialized_model_obj = json.loads(json_dump)
103+
elif file_adr is not None:
104+
with open(file_adr, 'r') as fp:
105+
serialized_model_obj = json.load(fp)
106+
else:
107+
raise Exception(INVALID_IMPORT_INIT_PARAMS)
98108
try:
99-
if json_dump and isinstance(json_dump, str):
100-
serialized_model_obj = json.loads(json_dump)
101-
else:
102-
with open(file_adr, 'r') as fp:
103-
serialized_model_obj = json.load(fp)
104-
if "pymilo_version" not in serialized_model_obj:
105-
raise Exception(PYMILO_VERSION_DOES_NOT_EXIST)
106109
if not serialized_model_obj["pymilo_version"] == PYMILO_VERSION:
107110
warn(UNEQUAL_PYMILO_VERSIONS, category=Warning)
108111
if not serialized_model_obj["sklearn_version"] == get_sklearn_version():
@@ -114,17 +117,19 @@ def __init__(self, file_adr, json_dump=None):
114117
json_content = None
115118
if json_dump and isinstance(json_dump, str):
116119
json_content = json_dump
117-
else:
120+
elif file_adr is not None:
118121
with open(file_adr) as f:
119122
json_content = f.readlines()
123+
else:
124+
json_content = serialized_model_obj
120125
raise PymiloDeserializationException(
121126
{
122127
'json_file': json_content,
123128
'error_type': DeserializationErrorTypes.CORRUPTED_JSON_FILE,
124129
'error': {
125130
'Exception': repr(e),
126131
'Traceback': format_exc()},
127-
'object': serialized_model_obj})
132+
'object': ""})
128133

129134
def to_model(self):
130135
"""

pymilo/pymilo_param.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
PYMILO_VERSION_DOES_NOT_EXIST = "Corrupted JSON file, `pymilo_version` doesn't exist in this file."
8888
UNEQUAL_PYMILO_VERSIONS = "warning: Installed PyMilo version differs from the PyMilo version used to create the JSON file."
8989
UNEQUAL_SKLEARN_VERSIONS = "warning: Installed Scikit version differs from the Scikit version used to create the JSON file and it may prevent PyMilo from transporting seamlessly."
90+
INVALID_IMPORT_INIT_PARAMS = "Invalid input parameters, you should either pass a valid file_adr or a json_dump or a url to initiate Import class."
91+
DOWNLOAD_MODEL_FAILED = "Failed to download the JSON file, Server didn't respond."
92+
INVALID_DOWNLOADED_MODEL = "The downloaded content is not a valid JSON file."
93+
9094

9195
SKLEARN_LINEAR_MODEL_TABLE = {
9296
"DummyRegressor": dummy.DummyRegressor,

pymilo/utils/util.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22
"""utility module."""
3-
from inspect import signature
3+
import requests
44
import importlib
5+
from inspect import signature
6+
from ..pymilo_param import DOWNLOAD_MODEL_FAILED, INVALID_DOWNLOADED_MODEL
57

68

79
def get_sklearn_type(model):
@@ -135,3 +137,31 @@ def prefix_list(list1, list2):
135137
if len(list1) < len(list2):
136138
return False
137139
return all(list1[j] == list2[j] for j in range(len(list2)))
140+
141+
142+
def download_model(url, timeout=10):
    """
    Download the model from the given url.

    Transient server errors (HTTP 500, 502, 503, 504) are retried up to 5 times with a short backoff.

    :param url: url to exported JSON file
    :type url: str
    :param timeout: number of seconds to wait for the server before giving up
    :type timeout: float or tuple
    :raises Exception: with DOWNLOAD_MODEL_FAILED when the request fails or the
        server answers with a non-200 status, and with INVALID_DOWNLOADED_MODEL
        when the response body is not valid JSON
    :return: obj
    """
    retries = requests.adapters.Retry(
        total=5,
        backoff_factor=0.1,
        status_forcelist=[500, 502, 503, 504]
    )
    adapter = requests.adapters.HTTPAdapter(max_retries=retries)
    # The context manager closes the session (and its pooled connections) on every exit path.
    with requests.Session() as s:
        s.mount('http://', adapter)
        s.mount('https://', adapter)
        try:
            # Without an explicit timeout, requests may block forever on an unresponsive server.
            response = s.get(url, timeout=timeout)
        except requests.exceptions.RequestException as e:
            raise Exception(DOWNLOAD_MODEL_FAILED) from e
        # Previously a non-200 answer fell through and returned None silently.
        if response.status_code != 200:
            raise Exception(DOWNLOAD_MODEL_FAILED)
        try:
            return response.json()
        except ValueError as e:
            raise Exception(INVALID_DOWNLOADED_MODEL) from e

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy>=1.9.0
22
scikit-learn>=0.22.2
33
scipy>=0.19.1
4+
requests>=2.0.0

tests/test_exceptions/import_exceptions.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# CORRUPTED_JSON_FILE = 1 -> tested.
22
# INVALID_MODEL = 2 -> tested.
33
# VALID_MODEL_INVALID_INTERNAL_STRUCTURE = 3 -> tested.
4+
import os
45
from pymilo.pymilo_obj import Import
56

6-
import os
77

88
def invalid_json(print_output = True):
99
json_files = ["corrupted", "unknown-model"]
@@ -16,4 +16,24 @@ def invalid_json(print_output = True):
1616
except Exception as e:
1717
if print_output: print("An Exception occured\n", e)
1818
return True
19-
19+
20+
def invalid_url():
    """
    Try to build an Import from the url `https://invalid_url`.

    :return: True if Import raised an exception, False if it succeeded
    """
    try:
        Import(url="https://invalid_url")
    except Exception:
        return True
    else:
        return False
27+
28+
def valid_url_invalid_file():
    """
    Try to build an Import from a sample JSON url hosted on filesamples.com.

    :return: True if Import raised an exception, False if it succeeded
    """
    sample_json_url = "https://filesamples.com/samples/code/json/sample1.json"
    try:
        Import(url=sample_json_url)
    except Exception:
        return True
    else:
        return False
35+
36+
# def valid_url_valid_file():
37+
# with pytest.raises(Exception):
38+
# url = "https://raw.githubusercontent.com/openscilab/pymilo/main/tests/test_exceptions/valid_jsons/linear_regression.json"
39+
# Import(url=url)

tests/test_exceptions/test_exceptions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
from export_exceptions import valid_model_invalid_structure_neural_network
44
from export_exceptions import valid_model_irrelevant_chain
55

6-
from import_exceptions import invalid_json
6+
from import_exceptions import invalid_json, invalid_url, valid_url_invalid_file#, valid_url_valid_file
77

88
EXCEPTION_TESTS = {
9-
'IMPORT': [invalid_json],
9+
'IMPORT': [
10+
invalid_json,
11+
invalid_url,
12+
valid_url_invalid_file,
13+
#valid_url_valid_file,
14+
],
1015
'EXPORT': [
1116
invalid_model,
1217
valid_model_invalid_structure_linear_model,
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
{
2+
"data": {
3+
"fit_intercept": true,
4+
"copy_X": true,
5+
"n_jobs": null,
6+
"positive": false,
7+
"n_features_in_": 10,
8+
"coef_": {
9+
"pymiloed-ndarray-list": [
10+
0.30609424754267966,
11+
-237.63557011300716,
12+
510.53804765114097,
13+
327.7298779909887,
14+
-814.1119263534517,
15+
492.7995945034062,
16+
102.84123996793083,
17+
184.6034960903708,
18+
743.5093875957093,
19+
76.09664636971895
20+
],
21+
"pymiloed-ndarray-dtype": "float64",
22+
"pymiloed-ndarray-shape": [
23+
10
24+
],
25+
"pymiloed-data-structure": "numpy.ndarray"
26+
},
27+
"rank_": 10,
28+
"singular_": {
29+
"pymiloed-ndarray-list": [
30+
1.9578051002417796,
31+
1.1797491126040702,
32+
1.0755406405377144,
33+
0.9579192686906345,
34+
0.7980638292867588,
35+
0.7594342409324799,
36+
0.7216957209064547,
37+
0.6459380350140406,
38+
0.27271507089040337,
39+
0.0915832239699
40+
],
41+
"pymiloed-ndarray-dtype": "float64",
42+
"pymiloed-ndarray-shape": [
43+
10
44+
],
45+
"pymiloed-data-structure": "numpy.ndarray"
46+
},
47+
"intercept_": {
48+
"value": 152.76429169049118,
49+
"np-type": "numpy.float64"
50+
}
51+
},
52+
"sklearn_version": "1.3.0",
53+
"pymilo_version": "0.9",
54+
"model_type": "LinearRegression"
55+
}

0 commit comments

Comments
 (0)