Skip to content

Commit f1dac48

Browse files
author
Ping-Han Hsieh
committed
pass expected_output_dataframe_type to make_target
1 parent 62ab1d8 commit f1dac48

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

gokart/target.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,13 @@ def make_target(
227227
processor: Optional[FileProcessor] = None,
228228
task_lock_params: Optional[TaskLockParams] = None,
229229
store_index_in_feather: bool = True,
230+
expected_dataframe_type: Optional[pa.DataFrameModel] = None,
230231
) -> TargetOnKart:
231232
_task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
232233
file_path = _make_file_path(file_path, unique_id)
233234
processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather)
234235
file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather)
235-
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)
236+
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=expected_dataframe_type)
236237

237238

238239
def make_model_target(

gokart/task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import luigi
1010
import pandas as pd
11+
import pandera as pa
1112
from luigi.parameter import ParameterVisibility
1213

1314
import gokart
@@ -83,6 +84,7 @@ class TaskOnKart(luigi.Task):
8384
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
8485
)
8586
should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.')
87+
expected_output_dataframe_type: Optional[pa.DataFrameModel] = None
8688

8789
def __init__(self, *args, **kwargs):
8890
self._add_configuration(kwargs, 'TaskOnKart')
@@ -192,7 +194,7 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
192194
)
193195

194196
return gokart.target.make_target(
195-
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
197+
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather, expected_dataframe_type=self.expected_output_dataframe_type
196198
)
197199

198200
def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:

test/test_target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import boto3
1010
import luigi
1111
import numpy as np
12-
import pandera as pa
1312
import pandas as pd
13+
import pandera as pa
1414
from matplotlib import pyplot
1515
from moto import mock_s3
1616

test/test_task_on_kart.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import luigi
99
import pandas as pd
10+
import pandera as pa
1011
from luigi.parameter import ParameterVisibility
1112
from luigi.util import inherits
1213

@@ -340,7 +341,19 @@ def test_fail_on_empty_dump(self):
340341
# fail
341342
task = _DummyTask(fail_on_empty_dump=True)
342343
self.assertRaises(AssertionError, lambda: task.dump(pd.DataFrame()))
343-
344+
345+
def test_fail_with_type_check(self):
346+
347+
class _DummyTypeSchema(pa.DataFrameModel):
348+
a: pa.typing.Series[int] = pa.Field()
349+
class _DummyTaskWithType(gokart.TaskOnKart):
350+
expected_output_dataframe_type = _DummyTypeSchema
351+
352+
task = _DummyTaskWithType()
353+
# fail
354+
with self.assertRaises(pa.errors.SchemaError):
355+
task.dump(pd.DataFrame(dict(a=['1', '2', '3'])))
356+
344357
@patch('luigi.configuration.get_config')
345358
def test_add_configuration(self, mock_config: MagicMock):
346359
mock_config.return_value = {'_DummyTask': {'list_param': '["c", "d"]', 'param': '3', 'bool_param': 'True'}}

0 commit comments

Comments
 (0)