Skip to content

Commit 0ac20e8

Browse files
committed
feat: add automatic processor selection based on DataFrame type parameter
1 parent 8826a52 commit 0ac20e8

File tree

3 files changed

+193
-3
lines changed

3 files changed

+193
-3
lines changed

gokart/task.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from gokart.required_task_output import RequiredTaskOutput
2626
from gokart.target import TargetOnKart
2727
from gokart.task_complete_check import task_complete_check_wrapper
28-
from gokart.utils import FlattenableItems, flatten, map_flattenable_items
28+
from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items
2929

3030
logger = getLogger(__name__)
3131

@@ -219,6 +219,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
219219
file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
220220
unique_id = self.make_unique_id() if use_unique_id else None
221221

222+
# Auto-select processor based on type parameter if not provided
223+
if processor is None and relative_file_path is not None:
224+
processor = self._create_processor_for_dataframe_type(file_path)
225+
222226
task_lock_params = make_task_lock_params(
223227
file_path=file_path,
224228
unique_id=unique_id,
@@ -232,6 +236,39 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
232236
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
233237
)
234238

239+
def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor | None:
240+
"""
241+
Create a file processor with appropriate return_type based on task's type parameter.
242+
243+
Args:
244+
file_path: Path to the file
245+
246+
Returns:
247+
FileProcessor with return_type set, or None to use default processor
248+
"""
249+
from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor
250+
251+
extension = os.path.splitext(file_path)[1]
252+
df_type = get_dataframe_type_from_task(self)
253+
254+
# Create custom processor for DataFrame-supporting file types with type parameter
255+
if extension == '.csv':
256+
return CsvFileProcessor(sep=',', return_type=df_type)
257+
elif extension == '.tsv':
258+
return CsvFileProcessor(sep='\t', return_type=df_type)
259+
elif extension == '.json':
260+
return JsonFileProcessor(orient=None, return_type=df_type)
261+
elif extension == '.ndjson':
262+
return JsonFileProcessor(orient='records', return_type=df_type)
263+
elif extension == '.parquet':
264+
return ParquetFileProcessor(return_type=df_type)
265+
elif extension == '.feather':
266+
# Note: store_index_in_feather is a required parameter, defaulting to False
267+
return FeatherFileProcessor(store_index_in_feather=False, return_type=df_type)
268+
269+
# For other file types, use default processor selection
270+
return None
271+
235272
def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
236273
formatted_relative_file_path = (
237274
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')

gokart/utils.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from collections.abc import Callable, Iterable
55
from io import BytesIO
6-
from typing import Any, Protocol, TypeAlias, TypeVar
6+
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin
77

88
import dill
99
import luigi
@@ -92,3 +92,48 @@ def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> An
9292
assert file.seekable(), f'{file} is not seekable.'
9393
file.seek(0)
9494
return pd.read_pickle(file)
95+
96+
97+
def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars']:
98+
"""
99+
Extract DataFrame type from TaskOnKart[T] type parameter.
100+
101+
Examines the type parameter T of a TaskOnKart subclass to determine
102+
whether it uses pandas or polars DataFrames.
103+
104+
Args:
105+
task: A TaskOnKart instance or class
106+
107+
Returns:
108+
'pandas' or 'polars' (defaults to 'pandas' if type cannot be determined)
109+
110+
Examples:
111+
>>> class MyTask(TaskOnKart[pd.DataFrame]): pass
112+
>>> get_dataframe_type_from_task(MyTask())
113+
'pandas'
114+
115+
>>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass
116+
>>> get_dataframe_type_from_task(MyPolarsTask())
117+
'polars'
118+
"""
119+
task_class = task if isinstance(task, type) else task.__class__
120+
121+
if not hasattr(task_class, '__orig_bases__'):
122+
return 'pandas'
123+
124+
for base in task_class.__orig_bases__:
125+
origin = get_origin(base)
126+
# Check if this is a TaskOnKart subclass
127+
if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart':
128+
args = get_args(base)
129+
if args:
130+
df_type = args[0]
131+
module = getattr(df_type, '__module__', '')
132+
133+
# Check module name to determine DataFrame type
134+
if 'polars' in module:
135+
return 'polars'
136+
elif 'pandas' in module:
137+
return 'pandas'
138+
139+
return 'pandas' # Default to pandas for backward compatibility

test/test_utils.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import unittest
22

3-
from gokart.utils import flatten, map_flattenable_items
3+
import pandas as pd
4+
import pytest
5+
6+
from gokart.task import TaskOnKart
7+
from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items
8+
9+
try:
10+
import polars as pl
11+
12+
HAS_POLARS = True
13+
except ImportError:
14+
HAS_POLARS = False
415

516

617
class TestFlatten(unittest.TestCase):
@@ -34,3 +45,100 @@ def test_map_flattenable_items(self):
3445
),
3546
{'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}},
3647
)
48+
49+
50+
class TestGetDataFrameTypeFromTask(unittest.TestCase):
51+
"""Tests for get_dataframe_type_from_task function."""
52+
53+
def test_pandas_dataframe_from_instance(self):
54+
"""Test detecting pandas DataFrame from task instance."""
55+
56+
class PandasTask(TaskOnKart[pd.DataFrame]):
57+
pass
58+
59+
task = PandasTask()
60+
self.assertEqual(get_dataframe_type_from_task(task), 'pandas')
61+
62+
def test_pandas_dataframe_from_class(self):
63+
"""Test detecting pandas DataFrame from task class."""
64+
65+
class PandasTask(TaskOnKart[pd.DataFrame]):
66+
pass
67+
68+
self.assertEqual(get_dataframe_type_from_task(PandasTask), 'pandas')
69+
70+
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
71+
def test_polars_dataframe_from_instance(self):
72+
"""Test detecting polars DataFrame from task instance."""
73+
74+
class PolarsTask(TaskOnKart[pl.DataFrame]):
75+
pass
76+
77+
task = PolarsTask()
78+
self.assertEqual(get_dataframe_type_from_task(task), 'polars')
79+
80+
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
81+
def test_polars_dataframe_from_class(self):
82+
"""Test detecting polars DataFrame from task class."""
83+
84+
class PolarsTask(TaskOnKart[pl.DataFrame]):
85+
pass
86+
87+
self.assertEqual(get_dataframe_type_from_task(PolarsTask), 'polars')
88+
89+
def test_no_type_parameter_defaults_to_pandas(self):
90+
"""Test that tasks without type parameter default to pandas."""
91+
92+
# Create a class without __orig_bases__ by not using type parameters
93+
class PlainTask:
94+
pass
95+
96+
task = PlainTask()
97+
self.assertEqual(get_dataframe_type_from_task(task), 'pandas')
98+
99+
def test_non_taskonkart_class_defaults_to_pandas(self):
100+
"""Test that non-TaskOnKart classes default to pandas."""
101+
102+
class RegularClass:
103+
pass
104+
105+
task = RegularClass()
106+
self.assertEqual(get_dataframe_type_from_task(task), 'pandas')
107+
108+
def test_taskonkart_with_non_dataframe_type(self):
109+
"""Test TaskOnKart with non-DataFrame type parameter defaults to pandas."""
110+
111+
class StringTask(TaskOnKart[str]):
112+
pass
113+
114+
task = StringTask()
115+
# Should default to pandas since str module is not 'pandas' or 'polars'
116+
self.assertEqual(get_dataframe_type_from_task(task), 'pandas')
117+
118+
def test_nested_inheritance_pandas(self):
119+
"""Test that nested inheritance without direct type parameter defaults to pandas."""
120+
121+
class BasePandasTask(TaskOnKart[pd.DataFrame]):
122+
pass
123+
124+
class DerivedPandasTask(BasePandasTask):
125+
pass
126+
127+
task = DerivedPandasTask()
128+
# DerivedPandasTask doesn't have its own __orig_bases__ with type parameter,
129+
# so it defaults to 'pandas'
130+
self.assertEqual(get_dataframe_type_from_task(task), 'pandas')
131+
132+
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
133+
def test_nested_inheritance_polars(self):
134+
"""Test detecting polars DataFrame type through nested inheritance."""
135+
136+
class BasePolarsTask(TaskOnKart[pl.DataFrame]):
137+
pass
138+
139+
class DerivedPolarsTask(BasePolarsTask):
140+
pass
141+
142+
task = DerivedPolarsTask()
143+
# Function should detect 'polars' through the inheritance chain
144+
self.assertEqual(get_dataframe_type_from_task(task), 'polars')

0 commit comments

Comments
 (0)