Skip to content

Commit 4ae10f0

Browse files
authored
Merge pull request #78 from hsj576/main
Merge of feature-lifelong-n branch
2 parents 4dadd17 + 8d832a8 commit 4ae10f0

File tree

248 files changed

+54262
-75
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

248 files changed

+54262
-75
lines changed

README_ospp.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# OSPP
2+
I changed the sedna source code to implement my algorithm.
3+
Please turn to https://github.com/kubeedge/sedna/pull/378 and https://github.com/nailtu30/sedna/blob/ospp-final/README_ospp.md for more information.

core/common/constant.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,46 @@ class ParadigmType(Enum):
3434
SINGLE_TASK_LEARNING = "singletasklearning"
3535
INCREMENTAL_LEARNING = "incrementallearning"
3636
MULTIEDGE_INFERENCE = "multiedgeinference"
37+
LIFELONG_LEARNING = "lifelonglearning"
3738

3839

3940
class ModuleType(Enum):
4041
"""
4142
Algorithm module type.
4243
"""
4344
BASEMODEL = "basemodel"
45+
46+
# HEM
4447
HARD_EXAMPLE_MINING = "hard_example_mining"
4548

49+
# STP
50+
TASK_DEFINITION = "task_definition"
51+
TASK_RELATIONSHIP_DISCOVERY = "task_relationship_discovery"
52+
TASK_ALLOCATION = "task_allocation"
53+
TASK_REMODELING = "task_remodeling"
54+
INFERENCE_INTEGRATE = "inference_integrate"
55+
56+
# KM
57+
TASK_UPDATE_DECISION = "task_update_decision"
58+
59+
# UTP
60+
UNSEEN_TASK_ALLOCATION = "unseen_task_allocation"
61+
62+
# UTD
63+
UNSEEN_SAMPLE_RECOGNITION = "unseen_sample_recognition"
64+
UNSEEN_SAMPLE_RE_RECOGNITION = "unseen_sample_re_recognition"
65+
4666

4767
class SystemMetricType(Enum):
4868
"""
4969
System metric type of ianvs.
5070
"""
71+
# pylint: disable=C0103
5172
SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio"
73+
FWT = "FWT"
74+
BWT = "BWT"
75+
Task_Avg_Acc = "Task_Avg_Acc"
76+
Matrix = "Matrix"
5277

5378
class TestObjectType(Enum):
5479
"""

core/storymanager/rank/rank.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pandas as pd
2222

2323
from core.common import utils
24-
from core.storymanager.visualization import get_visualization_func
24+
from core.storymanager.visualization import get_visualization_func, draw_heatmap_picture
2525

2626

2727
# pylint: disable=R0902
@@ -107,7 +107,10 @@ def _get_all_module_types(cls, test_cases) -> list:
107107
def _get_algorithm_hyperparameters(cls, algorithm):
108108
hps = {}
109109
for module in algorithm.modules.values():
110-
hps.update(**module.hyperparameters)
110+
for name, value in module.hyperparameters.items():
111+
name = f"{module.type}-{name}"
112+
value = str(value)
113+
hps.update({name: value})
111114
return hps
112115

113116
def _get_all_hps_names(self, test_cases) -> list:
@@ -170,6 +173,7 @@ def _get_all(self, test_cases, test_results) -> pd.DataFrame:
170173
return self._sort_all_df(all_df, self._get_all_metric_names(test_results))
171174

172175
def _save_all(self):
176+
# pylint: disable=E1101
173177
all_df = copy.deepcopy(self.all_df)
174178
all_df.index = pd.np.arange(1, len(all_df) + 1)
175179
all_df.to_csv(self.all_rank_file, index_label="rank", encoding="utf-8", sep=" ")
@@ -199,10 +203,21 @@ def _get_selected(self, test_cases, test_results) -> pd.DataFrame:
199203
return selected_df
200204

201205
def _save_selected(self, test_cases, test_results):
206+
# pylint: disable=E1101
202207
selected_df = self._get_selected(test_cases, test_results)
203208
selected_df.index = pd.np.arange(1, len(selected_df) + 1)
204209
selected_df.to_csv(self.selected_rank_file, index_label="rank", encoding="utf-8", sep=" ")
205210

211+
def _draw_pictures(self, test_cases, test_results):
212+
# pylint: disable=E1101
213+
for test_case in test_cases:
214+
out_put = test_case.output_dir
215+
test_result = test_results[test_case.id][0]
216+
matrix = test_result.get('Matrix')
217+
#print(out_put)
218+
for key in matrix.keys():
219+
draw_heatmap_picture(out_put, key, matrix[key])
220+
206221
def _prepare(self, test_cases, test_results, output_dir):
207222
all_metric_names = self._get_all_metric_names(test_results)
208223
all_hps_names = self._get_all_hps_names(test_cases)
@@ -241,6 +256,11 @@ def save(self, test_cases, test_results, output_dir):
241256
if self.save_mode == "selected_only":
242257
self._save_selected(test_cases, test_results)
243258

259+
if self.save_mode == "selected_and_all_and_picture":
260+
self._save_all()
261+
self._save_selected(test_cases, test_results)
262+
self._draw_pictures(test_cases, test_results)
263+
244264
def plot(self):
245265
"""
246266
plot rank according to the visual method, include

core/storymanager/visualization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=missing-module-docstring
16-
from .visualization import get_visualization_func
16+
from .visualization import get_visualization_func, draw_heatmap_picture

core/storymanager/visualization/visualization.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"""Visualization"""
1616

1717
import sys
18-
18+
import os
19+
import matplotlib.pyplot as plt
1920
from prettytable import from_csv
2021

2122

@@ -25,6 +26,23 @@ def print_table(rank_file):
2526
table = from_csv(file)
2627
print(table)
2728

29+
def draw_heatmap_picture(output, title, matrix):
30+
"""
31+
draw heatmap for results
32+
"""
33+
plt.figure(figsize=(10, 8), dpi=80)
34+
plt.imshow(matrix, cmap='bwr', extent=(0.5, len(matrix)+0.5, 0.5, len(matrix)+0.5),
35+
origin='lower')
36+
plt.xticks(fontsize=15)
37+
plt.yticks(fontsize=15)
38+
plt.xlabel('task round', fontsize=15)
39+
plt.ylabel('task', fontsize=15)
40+
plt.title(title, fontsize=15)
41+
plt.colorbar(format='%.2f')
42+
output_dir = os.path.join(output, f"output/{title}-heatmap.png")
43+
#print(output_dir)
44+
plt.savefig(output_dir)
45+
plt.show()
2846

2947
def get_visualization_func(mode):
3048
""" get visualization func """

core/testcasecontroller/algorithm/algorithm.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@
1717
import copy
1818

1919
from core.common.constant import ParadigmType
20+
from core.common.utils import load_module
2021
from core.testcasecontroller.algorithm.module import Module
2122
from core.testcasecontroller.algorithm.paradigm import (
2223
SingleTaskLearning,
2324
IncrementalLearning,
2425
MultiedgeInference,
26+
LifelongLearning,
2527
)
2628
from core.testcasecontroller.generation_assistant import get_full_combinations
2729

28-
29-
# pylint: disable=too-few-public-methods
3030
class Algorithm:
31+
# pylint: disable=too-many-instance-attributes
32+
# pylint: disable=too-few-public-methods
3133
"""
3234
Algorithm: typical distributed-synergy AI algorithm paradigm.
3335
Notes:
@@ -53,14 +55,20 @@ class Algorithm:
5355
def __init__(self, name, config):
5456
self.name = name
5557
self.paradigm_type: str = ""
58+
self.third_party_packages: list = []
5659
self.incremental_learning_data_setting: dict = {
5760
"train_ratio": 0.8,
5861
"splitting_method": "default"
5962
}
63+
self.lifelong_learning_data_setting: dict = {
64+
"train_ratio": 0.8,
65+
"splitting_method": "default"
66+
}
6067
self.initial_model_url: str = ""
6168
self.modules: list = []
6269
self.modules_list = None
6370
self._parse_config(config)
71+
self._load_third_party_packages()
6472

6573
def paradigm(self, workspace: str, **kwargs):
6674
"""
@@ -93,6 +101,9 @@ def paradigm(self, workspace: str, **kwargs):
93101
if self.paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value:
94102
return MultiedgeInference(workspace, **config)
95103

104+
if self.paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
105+
return LifelongLearning(workspace, **config)
106+
96107
return None
97108

98109
def _check_fields(self):
@@ -113,6 +124,11 @@ def _check_fields(self):
113124
f"algorithm incremental_learning_data_setting"
114125
f"({self.incremental_learning_data_setting} must be dictionary type.")
115126

127+
if not isinstance(self.lifelong_learning_data_setting, dict):
128+
raise ValueError(
129+
f"algorithm lifelong_learning_data_setting"
130+
f"({self.lifelong_learning_data_setting} must be dictionary type.")
131+
116132
if not isinstance(self.initial_model_url, str):
117133
raise ValueError(
118134
f"algorithm initial_model_url({self.initial_model_url}) must be string type.")
@@ -138,7 +154,7 @@ def _parse_modules_config(cls, config):
138154
for module in modules:
139155
hps_list = module.hyperparameters_list
140156
if not hps_list:
141-
modules_list.append((module.type, None))
157+
modules_list.append((module.type, [module]))
142158
continue
143159

144160
module_list = []
@@ -152,3 +168,16 @@ def _parse_modules_config(cls, config):
152168
module_combinations_list = get_full_combinations(modules_list)
153169

154170
return module_combinations_list
171+
172+
def _load_third_party_packages(self):
173+
if len(self.third_party_packages) == 0:
174+
return
175+
176+
for package in self.third_party_packages:
177+
name = package["name"]
178+
url = package["url"]
179+
try:
180+
load_module(url)
181+
except Exception as err:
182+
raise RuntimeError(f"load third party packages(name={name}, url={url}) failed,"
183+
f" error: {err}.") from err

core/testcasecontroller/algorithm/module/module.py

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from core.testcasecontroller.generation_assistant import get_full_combinations
2424

2525

26+
# pylint: disable=too-few-public-methods
2627
class Module:
2728
"""
2829
Algorithm Module:
@@ -52,8 +53,8 @@ def __init__(self, config):
5253
self.type: str = ""
5354
self.name: str = ""
5455
self.url: str = ""
55-
self.hyperparameters = None
56-
self.hyperparameters_list = None
56+
self.hyperparameters = {}
57+
self.hyperparameters_list = []
5758
self._parse_config(config)
5859

5960
def _check_fields(self):
@@ -71,75 +72,59 @@ def _check_fields(self):
7172
if not isinstance(self.url, str):
7273
raise ValueError(f"module url({self.url}) must be string type.")
7374

74-
def basemodel_func(self):
75+
def get_module_instance(self, module_type):
7576
"""
76-
get basemodel module function of the module.
77+
get function of algorithm module by using module type
78+
79+
Parameters
80+
---------
81+
module_type: string
82+
module type, e.g.: basemodel, hard_example_mining, etc.
7783
7884
Returns
79-
--------
85+
------
8086
function
8187
8288
"""
89+
class_factory_type = ClassType.GENERAL
90+
if module_type in [ModuleType.HARD_EXAMPLE_MINING.value]:
91+
class_factory_type = ClassType.HEM
8392

84-
if not self.url:
85-
raise ValueError(f"url({self.url}) of basemodel module must be provided.")
93+
elif module_type in [ModuleType.TASK_DEFINITION.value,
94+
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value,
95+
ModuleType.TASK_REMODELING.value,
96+
ModuleType.TASK_ALLOCATION.value,
97+
ModuleType.INFERENCE_INTEGRATE.value]:
98+
class_factory_type = ClassType.STP
8699

87-
try:
88-
utils.load_module(self.url)
89-
# pylint: disable=E1134
90-
basemodel = ClassFactory.get_cls(type_name=ClassType.GENERAL,
91-
t_cls_name=self.name)(**self.hyperparameters)
92-
except Exception as err:
93-
raise RuntimeError(f"basemodel module loads class(name={self.name}) failed, "
94-
f"error: {err}.") from err
100+
elif module_type in [ModuleType.TASK_UPDATE_DECISION.value]:
101+
class_factory_type = ClassType.KM
95102

96-
return basemodel
103+
elif module_type in [ModuleType.UNSEEN_TASK_ALLOCATION.value]:
104+
class_factory_type = ClassType.UTP
97105

98-
def hard_example_mining_func(self):
99-
"""
100-
get hard example mining function of the module.
101-
102-
Returns:
103-
--------
104-
function
105-
106-
"""
106+
elif module_type in [ModuleType.UNSEEN_SAMPLE_RECOGNITION.value,
107+
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value]:
108+
class_factory_type = ClassType.UTD
107109

108110
if self.url:
109111
try:
110112
utils.load_module(self.url)
111113
# pylint: disable=E1134
112114
func = ClassFactory.get_cls(
113-
type_name=ClassType.HEM, t_cls_name=self.name)(**self.hyperparameters)
115+
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)
114116

115117
return func
116118
except Exception as err:
117-
raise RuntimeError(f"hard_example_mining module loads class"
118-
f"(name={self.name}) failed, error: {err}.") from err
119+
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
120+
f"failed, error: {err}.") from err
119121

120-
# call built-in hard example mining function
121-
hard_example_mining = {"method": self.name}
122+
# call lib built-in module function
123+
module_func = {"method": self.name}
122124
if self.hyperparameters:
123-
hard_example_mining["param"] = self.hyperparameters
124-
125-
return hard_example_mining
125+
module_func["param"] = self.hyperparameters
126126

127-
def get_module_func(self, module_type):
128-
"""
129-
get function of algorithm module by using module type
130-
131-
Parameters
132-
---------
133-
module_type: string
134-
module type, e.g.: basemodel, hard_example_mining, etc.
135-
136-
Returns
137-
------
138-
function
139-
140-
"""
141-
func_name = f"{module_type}_func"
142-
return getattr(self, func_name)
127+
return module_func
143128

144129
def _parse_config(self, config):
145130
# pylint: disable=C0103

core/testcasecontroller/algorithm/paradigm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .incremental_learning import IncrementalLearning
1717
from .singletask_learning import SingleTaskLearning
1818
from .multiedge_inference import MultiedgeInference
19+
from .lifelong_learning import LifelongLearning

0 commit comments

Comments
 (0)