Skip to content

Commit 07ddb96

Browse files
Update helper_functions.py
1 parent 7029fc2 commit 07ddb96

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

DimRed/helper_functions.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,29 @@
11
from DimRed import *
2+
3+
4+
def load_dataset(test_split: float = 0.25) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
5+
digits = datasets.load_digits()
6+
X, y = digits.data, digits.target
7+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split, stratify=y,
8+
random_state=RANDOM_STATE, shuffle=True)
9+
return X_train, X_test, y_train, y_test
10+
11+
12+
def label_encoding(y_train: np.ndarray, y_test: np.ndarray) -> Tuple:
13+
le = LabelEncoder()
14+
y_train = le.fit_transform(y_train)
15+
y_test = le.transform(y_test)
16+
return y_train, y_test
17+
18+
19+
def average_metric(metric, dictionaries: List[Dict[str, Union[str, int]]]) -> float:
20+
avg = 0
21+
for dictionary in dictionaries:
22+
avg += dictionary[metric]
23+
return float(avg / len(dictionaries))
24+
25+
26+
def add_to_dictionary(dictionary: Dict[str, List[Union[str, int]]], list_of_values: List[Union[str, int]]) -> Dict[str, List[Union[str, int]]]:
27+
for idx, key in enumerate(dictionary):
28+
dictionary[key].append(list_of_values[idx])
29+
return dictionary

0 commit comments

Comments
 (0)