Skip to content

Commit 18ef5ab

Browse files
committed
moved mlflow into run and added tutorial
1 parent 0fefac7 commit 18ef5ab

File tree

4 files changed

+268
-14
lines changed

4 files changed

+268
-14
lines changed

cool_graph/runners.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,6 @@ def __init__(
362362
for k, v in kwargs.items():
363363
setattr(self, k, v)
364364

365-
if self.cfg["logging"].get("use_mlflow", False):
366-
setup_mlflow_from_config(cfg["logging"]["mlflow"])
367-
368365
def sample_data(self) -> None:
369366
"""
370367
Sampling data into batches and sampling data with NeighborLoader into list loaders.
@@ -521,7 +518,10 @@ def run(self, train_loader=None, test_loader=None) -> Dict[str, float]:
521518
"""
522519
self.train_loader = train_loader
523520
self.test_loader = test_loader
524-
521+
522+
if self.cfg["logging"].get("use_mlflow", False):
523+
setup_mlflow_from_config(self.cfg["logging"]["mlflow"])
524+
525525
if (self.train_loader is None) and (self.test_loader is None):
526526
self.sample_data()
527527
elif "index" not in self.data.keys:
@@ -800,7 +800,10 @@ def optimize_run(
800800
Returns:
801801
trials_dataset (pd.DataFrame): Result dataframe with trial params.
802802
"""
803-
803+
804+
if self.cfg["logging"].get("use_mlflow", False):
805+
setup_mlflow_from_config(self.cfg["logging"]["mlflow"])
806+
804807
self.train_loader = train_loader
805808
self.test_loader = test_loader
806809

cool_graph/train/trainer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,10 @@ def train(self, start_epoch: int = 0, end_epoch: Optional[int] = None) -> Dict[
251251
embedding_data=self.embedding_data,
252252
)
253253
test_tasks = test_metric["tasks"]
254-
self.mlflow_log_metrics(
255-
metrics=add_prefix_to_dict_keys(test_metric, "test_"), step=epoch
256-
)
254+
for task in test_tasks:
255+
self.mlflow_log_metrics(
256+
metrics=add_prefix_to_dict_keys(test_tasks[task], f"{task}_test_"), step=epoch
257+
)
257258
test_metric["epoch"] = epoch
258259
self._test_metric_lst.append(test_metric)
259260
with open(
@@ -280,9 +281,10 @@ def train(self, start_epoch: int = 0, end_epoch: Optional[int] = None) -> Dict[
280281
embedding_data=self.embedding_data,
281282
)
282283
train_tasks = train_metric["tasks"]
283-
self.mlflow_log_metrics(
284-
metrics=add_prefix_to_dict_keys(train_metric, "train_"), step=epoch
285-
)
284+
for task in train_tasks:
285+
self.mlflow_log_metrics(
286+
metrics=add_prefix_to_dict_keys(train_tasks[task], f"{task}_train_"), step=epoch
287+
)
286288
train_metric["epoch"] = epoch
287289
self._train_metric_lst.append(train_metric)
288290
with open(
@@ -339,9 +341,10 @@ def train(self, start_epoch: int = 0, end_epoch: Optional[int] = None) -> Dict[
339341
test_metric = pd.DataFrame(self._test_metric_lst)
340342
train_metric = pd.DataFrame(self._train_metric_lst)
341343

342-
self.mlflow_log_metrics(
343-
metrics=add_prefix_to_dict_keys(self._best_loss, "best_")
344-
)
344+
for task in self._best_loss["tasks"]:
345+
self.mlflow_log_metrics(
346+
metrics=add_prefix_to_dict_keys(self._best_loss["tasks"][task], f"{task}_best_")
347+
)
345348
self.mlflow_log_metrics({"global_calc_time": self.global_calc_time})
346349

347350
if self.use_mlflow:
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "e81c847e",
6+
"metadata": {},
7+
"source": [
8+
"# Integration with mlflow"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "b6aa588c",
14+
"metadata": {},
15+
"source": [
16+
"### Summary: \n",
17+
"##### Let's change config parameters to track training with mlflow"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": 3,
23+
"id": "35d5f728",
24+
"metadata": {
25+
"execution": {
26+
"iopub.execute_input": "2024-08-21T10:36:36.576223Z",
27+
"iopub.status.busy": "2024-08-21T10:36:36.575851Z",
28+
"iopub.status.idle": "2024-08-21T10:36:36.579517Z",
29+
"shell.execute_reply": "2024-08-21T10:36:36.579010Z",
30+
"shell.execute_reply.started": "2024-08-21T10:36:36.576205Z"
31+
}
32+
},
33+
"outputs": [],
34+
"source": [
35+
"# imports for loading the dataset\n",
36+
"from torch_geometric import datasets\n",
37+
"import torch\n",
38+
"import pandas as pd\n",
39+
"from torch_geometric.data import Data\n",
40+
"# importing Runner\n",
41+
"from cool_graph.runners import Runner"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 6,
47+
"id": "d82a914c",
48+
"metadata": {
49+
"execution": {
50+
"iopub.execute_input": "2024-08-21T10:37:07.692657Z",
51+
"iopub.status.busy": "2024-08-21T10:37:07.692267Z",
52+
"iopub.status.idle": "2024-08-21T10:37:29.043296Z",
53+
"shell.execute_reply": "2024-08-21T10:37:29.042582Z",
54+
"shell.execute_reply.started": "2024-08-21T10:37:07.692635Z"
55+
}
56+
},
57+
"outputs": [
58+
{
59+
"name": "stderr",
60+
"output_type": "stream",
61+
"text": [
62+
"Downloading https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_computers.npz\n",
63+
"Processing...\n",
64+
"Done!\n"
65+
]
66+
},
67+
{
68+
"data": {
69+
"text/plain": [
70+
"Data(x=[13752, 767], edge_index=[2, 491722], y=[13752])"
71+
]
72+
},
73+
"execution_count": 6,
74+
"metadata": {},
75+
"output_type": "execute_result"
76+
}
77+
],
78+
"source": [
79+
"# use simple Amazon dataset with Computers\n",
80+
"dataset = datasets.Amazon(root='./data/Amazon', name='Computers')\n",
81+
"data = dataset.data\n",
82+
"data"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": 7,
88+
"id": "3faa80be",
89+
"metadata": {
90+
"execution": {
91+
"iopub.execute_input": "2024-08-21T10:39:13.803082Z",
92+
"iopub.status.busy": "2024-08-21T10:39:13.802561Z",
93+
"iopub.status.idle": "2024-08-21T10:39:13.951912Z",
94+
"shell.execute_reply": "2024-08-21T10:39:13.950993Z",
95+
"shell.execute_reply.started": "2024-08-21T10:39:13.803052Z"
96+
}
97+
},
98+
"outputs": [],
99+
"source": [
100+
"# initializing Runner\n",
101+
"runner = Runner(data)"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": 8,
107+
"id": "c0c84614",
108+
"metadata": {
109+
"execution": {
110+
"iopub.execute_input": "2024-08-21T10:55:27.607989Z",
111+
"iopub.status.busy": "2024-08-21T10:55:27.607408Z",
112+
"iopub.status.idle": "2024-08-21T10:55:27.612779Z",
113+
"shell.execute_reply": "2024-08-21T10:55:27.612280Z",
114+
"shell.execute_reply.started": "2024-08-21T10:55:27.607965Z"
115+
}
116+
},
117+
"outputs": [],
118+
"source": [
119+
"import urllib3\n",
120+
"\n",
121+
"runner.cfg[\"logging\"][\"use_mlflow\"] = True # making flag True to use mlflow\n",
122+
"runner.cfg[\"logging\"][\"mlflow\"] = {\n",
123+
" \"MLFLOW_TRACKING_URI\": \"https://ml-flow.msk.bd-cloud.mts.ru/\", # uri of \n",
124+
" \"MLFLOW_TRACKING_USERNAME\": \"username\",\n",
125+
" \"MLFLOW_TRACKING_PASSWORD\": \"password\",\n",
126+
" \"MLFLOW_S3_ENDPOINT_URL\": \"https://s3.mts-corp.ru\", # to save artifacts\n",
127+
" \"AWS_ACCESS_KEY_ID\": \"access_key\", # to save artifacts\n",
128+
" \"AWS_SECRET_ACCESS_KEY\": \"secret_access_key\", # to save artifacts\n",
129+
" \"MLFLOW_TRACKING_INSECURE_TLS\": \"true\", # to ignore the TLS certificate verification\n",
130+
" \"MLFLOW_S3_IGNORE_TLS\": \"true\", # to ignore the TLS certificate verification\n",
131+
" \"MLFLOW_DISABLE_INSECURE_REQUEST_WARNING\": True # to disable warnings\n",
132+
" }\n",
133+
"runner.cfg[\"logging\"][\"mlflow_experiment_name\"] = \"coolgraph_example\" # name of experiment\n",
134+
"\n",
135+
"urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # disabling request warning"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": 9,
141+
"id": "f154a32f",
142+
"metadata": {
143+
"execution": {
144+
"iopub.execute_input": "2024-08-21T10:55:36.670982Z",
145+
"iopub.status.busy": "2024-08-21T10:55:36.670511Z",
146+
"iopub.status.idle": "2024-08-21T10:56:53.936843Z",
147+
"shell.execute_reply": "2024-08-21T10:56:53.935934Z",
148+
"shell.execute_reply.started": "2024-08-21T10:55:36.670959Z"
149+
}
150+
},
151+
"outputs": [
152+
{
153+
"name": "stderr",
154+
"output_type": "stream",
155+
"text": [
156+
"Sample data: 100%|██████████| 42/42 [00:04<00:00, 10.12it/s]\n",
157+
"Sample data: 100%|██████████| 14/14 [00:00<00:00, 14.57it/s]\n",
158+
"2024/08/21 13:55:44 INFO mlflow.tracking.fluent: Experiment with name 'coolgraph_example' does not exist. Creating a new experiment.\n",
159+
"2024-08-21 13:55:51 - epoch 0 test: \n",
160+
" {'accuracy': 0.505, 'cross_entropy': 1.31, 'f1_weighted': 0.451, 'calc_time': 0.006, 'main_metric': 0.505}\n",
161+
"2024-08-21 13:55:53 - epoch 0 train: \n",
162+
" {'accuracy': 0.497, 'cross_entropy': 1.313, 'f1_weighted': 0.442, 'calc_time': 0.016, 'main_metric': 0.497}\n",
163+
"2024-08-21 13:56:06 - epoch 5 test: \n",
164+
" {'accuracy': 0.9, 'cross_entropy': 0.299, 'f1_weighted': 0.899, 'calc_time': 0.006, 'main_metric': 0.9}\n",
165+
"2024-08-21 13:56:07 - epoch 5 train: \n",
166+
" {'accuracy': 0.919, 'cross_entropy': 0.246, 'f1_weighted': 0.918, 'calc_time': 0.012, 'main_metric': 0.919}\n",
167+
"2024-08-21 13:56:20 - epoch 10 test: \n",
168+
" {'accuracy': 0.918, 'cross_entropy': 0.268, 'f1_weighted': 0.918, 'calc_time': 0.006, 'main_metric': 0.918}\n",
169+
"2024-08-21 13:56:21 - epoch 10 train: \n",
170+
" {'accuracy': 0.953, 'cross_entropy': 0.156, 'f1_weighted': 0.953, 'calc_time': 0.014, 'main_metric': 0.953}\n",
171+
"2024-08-21 13:56:34 - epoch 15 test: \n",
172+
" {'accuracy': 0.92, 'cross_entropy': 0.277, 'f1_weighted': 0.92, 'calc_time': 0.005, 'main_metric': 0.92}\n",
173+
"2024-08-21 13:56:35 - epoch 15 train: \n",
174+
" {'accuracy': 0.961, 'cross_entropy': 0.124, 'f1_weighted': 0.961, 'calc_time': 0.011, 'main_metric': 0.961}\n",
175+
"2024-08-21 13:56:48 - epoch 20 test: \n",
176+
" {'accuracy': 0.921, 'cross_entropy': 0.296, 'f1_weighted': 0.92, 'calc_time': 0.008, 'main_metric': 0.921}\n",
177+
"2024-08-21 13:56:50 - epoch 20 train: \n",
178+
" {'accuracy': 0.969, 'cross_entropy': 0.098, 'f1_weighted': 0.969, 'calc_time': 0.012, 'main_metric': 0.969}\n"
179+
]
180+
}
181+
],
182+
"source": [
183+
"result = runner.run()"
184+
]
185+
},
186+
{
187+
"cell_type": "markdown",
188+
"id": "c835b041",
189+
"metadata": {},
190+
"source": [
191+
"### Let's see the results on Mlflow tracker"
192+
]
193+
},
194+
{
195+
"cell_type": "markdown",
196+
"id": "c67c3798",
197+
"metadata": {
198+
"execution": {
199+
"iopub.execute_input": "2024-08-21T11:08:43.663318Z",
200+
"iopub.status.busy": "2024-08-21T11:08:43.662778Z",
201+
"iopub.status.idle": "2024-08-21T11:08:43.939813Z",
202+
"shell.execute_reply": "2024-08-21T11:08:43.938751Z",
203+
"shell.execute_reply.started": "2024-08-21T11:08:43.663287Z"
204+
}
205+
},
206+
"source": [
207+
"![mlflow_result](./src/image_2024-08-21_13-59-13.png) "
208+
]
209+
},
210+
{
211+
"cell_type": "markdown",
212+
"id": "cbac9ca5",
213+
"metadata": {},
214+
"source": [
215+
"### Success!"
216+
]
217+
},
218+
{
219+
"cell_type": "code",
220+
"execution_count": null,
221+
"id": "686924be",
222+
"metadata": {},
223+
"outputs": [],
224+
"source": []
225+
}
226+
],
227+
"metadata": {
228+
"kernelspec": {
229+
"display_name": "CGKerner",
230+
"language": "python",
231+
"name": "cgkerner"
232+
},
233+
"language_info": {
234+
"codemirror_mode": {
235+
"name": "ipython",
236+
"version": 3
237+
},
238+
"file_extension": ".py",
239+
"mimetype": "text/x-python",
240+
"name": "python",
241+
"nbconvert_exporter": "python",
242+
"pygments_lexer": "ipython3",
243+
"version": "3.8.13"
244+
}
245+
},
246+
"nbformat": 4,
247+
"nbformat_minor": 5
248+
}
180 KB
Loading

0 commit comments

Comments
 (0)