Skip to content

Commit ff4c86a

Browse files
maronuu and yutaro-oguri authored
Use pickle instead of serialize (#29)
* Use pickle instead of serialize
* Fix path to pkl
* Add func to generate pickle path
* Fmt
* Fix target
* Fix target
* Remove unused import
* Change pkl path
* Update example
* fmt

---------

Co-authored-by: yutaro-oguri <[email protected]>
1 parent b8ecfc8 commit ff4c86a

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

example/run_child.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@
44
import fire
55
import gokart
66
import luigi
7-
# Import task definition
8-
import tasks # noqa: F401
7+
from gokart.target import make_target
98

109
logging.basicConfig(level=logging.INFO)
1110

1211

13-
def main(serialized_task: str) -> None:
12+
def main(task_pkl_path: str) -> None:
1413
# Load luigi config
1514
luigi.configuration.LuigiConfigParser.add_config_path("./conf/base.ini")
1615

1716
# Parse a serialized gokart.TaskOnKart
18-
task: gokart.TaskOnKart = gokart.TaskInstanceParameter().parse(serialized_task)
17+
task: gokart.TaskOnKart = make_target(task_pkl_path).load()
1918

2019
# Run gokart.build
2120
gokart.build(task)

kannon/master.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Deque, Dict, List, Optional, Set
77

88
import gokart
9+
from gokart.target import make_target
910
from kubernetes import client
1011
from luigi.task import flatten
1112

@@ -107,25 +108,27 @@ def _exec_gokart_task(self, task: gokart.TaskOnKart) -> None:
107108
raise RuntimeError(f"Task {self._gen_task_info(task)} on job master has failed.")
108109

109110
def _exec_bullet_task(self, task: TaskOnBullet) -> None:
111+
# Save task instance as pickle object
112+
pkl_path = self._gen_pkl_path(task)
113+
make_target(pkl_path).dump(task)
110114
# Run on child job
111-
serialized_task = gokart.TaskInstanceParameter().serialize(task)
112115
job_name = gen_job_name(f"{self.job_prefix}-{task.get_task_family()}")
113116
job = self._create_child_job_object(
114117
job_name=job_name,
115-
serialized_task=serialized_task,
118+
task_pkl_path=pkl_path,
116119
)
117120
create_job(self.api_instance, job, self.namespace)
118121
logger.info(f"Created child job {job_name} with task {self._gen_task_info(task)}")
119122
task_unique_id = task.make_unique_id()
120123
self.task_id_to_job_name[task_unique_id] = job_name
121124

122-
def _create_child_job_object(self, job_name: str, serialized_task: str) -> client.V1Job:
125+
def _create_child_job_object(self, job_name: str, task_pkl_path: str) -> client.V1Job:
123126
# TODO: use python -c to avoid dependency to execute_task.py
124127
cmd = [
125128
"python",
126129
self.path_child_script,
127-
"--serialized-task",
128-
f"'{serialized_task}'",
130+
"--task-pkl-path",
131+
f"'{task_pkl_path}'",
129132
]
130133
job = deepcopy(self.template_job)
131134
# replace command
@@ -150,6 +153,10 @@ def _create_child_job_object(self, job_name: str, serialized_task: str) -> clien
150153
def _gen_task_info(task: gokart.TaskOnKart) -> str:
151154
return f"{task.get_task_family()}_{task.make_unique_id()}"
152155

156+
@staticmethod
157+
def _gen_pkl_path(task: gokart.TaskOnKart) -> str:
158+
return os.path.join(task.workspace_directory, 'kannon', f'task_obj_{task.make_unique_id()}.pkl')
159+
153160
def _is_executable(self, task: gokart.TaskOnKart) -> bool:
154161
children = flatten(task.requires())
155162

test/test_master.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_success_basic(self) -> None:
7676
class Example(gokart.TaskOnKart):
7777
pass
7878

79-
serialized_task = gokart.TaskInstanceParameter().serialize(Example())
79+
path_to_pkl = "path/to/obj"
8080
template_job = self._get_template_job()
8181
master = Kannon(
8282
api_instance=None,
@@ -88,7 +88,7 @@ class Example(gokart.TaskOnKart):
8888
# set os env
8989
os.environ.update({"TASK_WORKSPACE_DIRECTORY": "/cache"})
9090
child_job_name = "test-job"
91-
child_job = master._create_child_job_object(child_job_name, serialized_task)
91+
child_job = master._create_child_job_object(child_job_name, path_to_pkl)
9292

9393
# following should be copied from template_job
9494
self.assertEqual(child_job.api_version, template_job.api_version)
@@ -99,7 +99,7 @@ class Example(gokart.TaskOnKart):
9999
self.assertEqual(child_job.spec.template.spec.containers[0].image, template_job.spec.template.spec.containers[0].image)
100100
self.assertEqual(child_job.spec.template.spec.restart_policy, template_job.spec.template.spec.restart_policy)
101101
# following should be overwritten
102-
self.assertEqual(child_job.spec.template.spec.containers[0].command, ["python", __file__, "--serialized-task", f"'{serialized_task}'"])
102+
self.assertEqual(child_job.spec.template.spec.containers[0].command, ["python", __file__, "--task-pkl-path", f"'{path_to_pkl}'"])
103103
self.assertEqual(child_job.metadata.name, child_job_name)
104104
# envvar TASK_WORKSPACE_DIRECTORY should be inherited
105105
child_env = child_job.spec.template.spec.containers[0].env
@@ -111,7 +111,7 @@ def test_success_custom_env(self) -> None:
111111
class Example(gokart.TaskOnKart):
112112
pass
113113

114-
serialized_task = gokart.TaskInstanceParameter().serialize(Example())
114+
path_to_pkl = "path/to/obj"
115115
template_job = self._get_template_job()
116116
master = Kannon(
117117
api_instance=None,
@@ -123,7 +123,7 @@ class Example(gokart.TaskOnKart):
123123
# set os env
124124
os.environ.update({"TASK_WORKSPACE_DIRECTORY": "/cache", "MY_ENV0": "env0", "MY_ENV1": "env1"})
125125
child_job_name = "test-job"
126-
child_job = master._create_child_job_object(child_job_name, serialized_task)
126+
child_job = master._create_child_job_object(child_job_name, path_to_pkl)
127127

128128
child_env = child_job.spec.template.spec.containers[0].env
129129
self.assertEqual(len(child_env), 3)
@@ -136,7 +136,7 @@ def test_fail_command_set(self) -> None:
136136
class Example(gokart.TaskOnKart):
137137
pass
138138

139-
serialized_task = gokart.TaskInstanceParameter().serialize(Example())
139+
path_to_pkl = "path/to/obj"
140140
template_job = self._get_template_job()
141141
template_job.spec.template.spec.containers[0].command = ["dummy-command"]
142142
master = Kannon(
@@ -149,14 +149,14 @@ class Example(gokart.TaskOnKart):
149149
# set os env
150150
os.environ.update({"TASK_WORKSPACE_DIRECTORY": "/cache"})
151151
with self.assertRaises(AssertionError):
152-
master._create_child_job_object("test-job", serialized_task)
152+
master._create_child_job_object("test-job", path_to_pkl)
153153

154154
def test_fail_default_env_not_exist(self) -> None:
155155

156156
class Example(gokart.TaskOnKart):
157157
pass
158158

159-
serialized_task = gokart.TaskInstanceParameter().serialize(Example())
159+
path_to_pkl = "path/to/obj"
160160
template_job = self._get_template_job()
161161

162162
cases = [None, ["TASK_WORKSPACE_DIRECTORY", "MY_ENV0", "MY_ENV1"]]
@@ -170,7 +170,7 @@ class Example(gokart.TaskOnKart):
170170
env_to_inherit=case,
171171
)
172172
with self.assertRaises(ValueError):
173-
master._create_child_job_object("test-job", serialized_task)
173+
master._create_child_job_object("test-job", path_to_pkl)
174174

175175

176176
if __name__ == '__main__':

0 commit comments

Comments
 (0)