Skip to content

Commit 00c846a

Browse files
authored
To pandas - hierarchical multi header (#22)
1 parent 095952c commit 00c846a

File tree

9 files changed

+107
-109
lines changed

9 files changed

+107
-109
lines changed

examples/llm-claude-aggregate-query.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22

33
import anthropic
4-
import pandas as pd
54
from anthropic.types import Message
65

76
from datachain import Column, DataChain
@@ -55,6 +54,4 @@
5554
)
5655
)
5756

58-
with pd.option_context("display.max_columns", None):
59-
df = chain.to_pandas()
60-
print(df)
57+
chain.show()

examples/llm-claude-simple-query.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33

44
import anthropic
5-
import pandas as pd
65
from anthropic.types import Message
76
from pydantic import BaseModel
87

@@ -62,6 +61,4 @@ class Rating(BaseModel):
6261
)
6362
)
6463

65-
with pd.option_context("display.max_columns", None):
66-
df = chain.to_pandas()
67-
print(df)
64+
chain.show()

examples/llm-claude.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22

33
import anthropic
4-
import pandas as pd
54
from anthropic.types import Message
65

76
from datachain import Column, DataChain, File
@@ -37,6 +36,4 @@
3736
)
3837
)
3938

40-
with pd.option_context("display.max_columns", None):
41-
df = chain.to_pandas()
42-
print(df)
39+
chain.show()

src/datachain/lib/dc.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Union,
1212
)
1313

14+
import pandas as pd
1415
import sqlalchemy
1516
from pydantic import BaseModel, create_model
1617

@@ -38,9 +39,9 @@
3839
detach,
3940
)
4041
from datachain.query.schema import Column, DatasetRow
42+
from datachain.utils import inside_notebook
4143

4244
if TYPE_CHECKING:
43-
import pandas as pd
4445
from typing_extensions import Self
4546

4647
C = Column
@@ -731,6 +732,37 @@ def from_pandas( # type: ignore[override]
731732

732733
return cls.from_values(name, session, object_name=object_name, **fr_map)
733734

735+
def to_pandas(self, flatten=False) -> "pd.DataFrame":
    """Materialize the chain as a pandas DataFrame.

    Args:
        flatten: when True (or when no signal is nested), produce a flat
            frame with dotted column names ("signal.field"); otherwise
            build hierarchical (MultiIndex) columns from the signal tree.

    Returns:
        A DataFrame holding every row of the chain.
    """
    headers, depth = self.signals_schema.get_headers_with_length()
    if flatten or depth < 2:
        frame = pd.DataFrame.from_records(self.to_records())
        if headers:
            # Join non-empty path segments into dotted column names.
            frame.columns = [
                ".".join(part for part in header if part) for header in headers
            ]
        return frame

    # Transpose row tuples into per-column value lists, then key each
    # column by its full header path to get a MultiIndex frame.
    columns = [list(col) for col in zip(*self.results())]
    return pd.DataFrame(
        {tuple(header): values for header, values in zip(headers, columns)}
    )
746+
747+
def show(self, limit: int = 20, flatten=False, transpose=False) -> None:
    """Print the first `limit` rows of the chain as a pandas DataFrame.

    Args:
        limit: maximum number of rows to display; a non-positive value
            disables the limit and shows the whole chain.
        flatten: forwarded to `to_pandas()` — use flat dotted column
            names instead of hierarchical (MultiIndex) columns.
        transpose: display the frame transposed (rows become columns).
    """
    dc = self.limit(limit) if limit > 0 else self
    df = dc.to_pandas(flatten)
    # Capture the row count before transposing: after `df.T`, len(df)
    # counts the original columns, which broke the truncation notice.
    n_rows = len(df)
    if transpose:
        df = df.T

    with pd.option_context(
        "display.max_columns", None, "display.multi_sparse", False
    ):
        if inside_notebook():
            from IPython.display import display

            display(df)
        else:
            print(df)

    # Only report truncation when a limit was actually applied; the old
    # `len(df) == limit` check also fired for an empty frame at limit=0.
    if limit > 0 and n_rows == limit:
        print(f"\n[Limited by {n_rows} rows]")
765+
734766
def parse_tabular(
735767
self,
736768
output: OutputType = None,

src/datachain/lib/signal_schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,16 @@ def print_tree(self, indent: int = 4, start_at: int = 0):
338338
sub_schema = SignalSchema({"* list of": args[0]})
339339
sub_schema.print_tree(indent=indent, start_at=total_indent + indent)
340340

341+
def get_headers_with_length(self):
    """Collect leaf signal paths padded to a uniform depth.

    Returns:
        A tuple `(headers, depth)` where `headers` is the list of leaf
        paths, each right-padded with "" up to the deepest path length,
        and `depth` is that maximum length (0 for an empty schema).
    """
    leaf_paths = []
    for path, _type, has_subtree, _depth in self.get_flat_tree():
        # Only leaves become columns; subtree nodes are structural.
        if not has_subtree:
            leaf_paths.append(path)

    depth = max((len(p) for p in leaf_paths), default=0)
    padded = [
        p if len(p) == depth else p + [""] * (depth - len(p))
        for p in leaf_paths
    ]
    return padded, depth
350+
341351
def __or__(self, other):
342352
return self.__class__(self.values | other.values)
343353

src/datachain/query/dataset.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
)
2727

2828
import attrs
29-
import pandas as pd
3029
import sqlalchemy
3130
from attrs import frozen
3231
from dill import dumps, source
@@ -53,10 +52,9 @@
5352
from datachain.dataset import DatasetStatus, RowDict
5453
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
5554
from datachain.progress import CombinedDownloadCallback
56-
from datachain.query.schema import DEFAULT_DELIMITER
5755
from datachain.sql.functions import rand
5856
from datachain.storage import Storage, StorageURI
59-
from datachain.utils import batched, determine_processes, inside_notebook
57+
from datachain.utils import batched, determine_processes
6058

6159
from .metrics import metrics
6260
from .schema import C, UDFParamSpec, normalize_param
@@ -1346,12 +1344,6 @@ async def get_params(row: RowDict) -> tuple:
13461344
def to_records(self) -> list[dict[str, Any]]:
13471345
return self.results(lambda cols, row: dict(zip(cols, row)))
13481346

1349-
def to_pandas(self) -> "pd.DataFrame":
1350-
records = self.to_records()
1351-
df = pd.DataFrame.from_records(records)
1352-
df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
1353-
return df
1354-
13551347
def shuffle(self) -> "Self":
13561348
# ToDo: implement shuffle based on seed and/or generating random column
13571349
return self.order_by(C.sys__rand)
@@ -1370,22 +1362,6 @@ def sample(self, n) -> "Self":
13701362

13711363
return sampled.limit(n)
13721364

1373-
def show(self, limit=20) -> None:
1374-
df = self.limit(limit).to_pandas()
1375-
1376-
options = ["display.max_colwidth", 50, "display.show_dimensions", False]
1377-
with pd.option_context(*options):
1378-
if inside_notebook():
1379-
from IPython.display import display
1380-
1381-
display(df)
1382-
1383-
else:
1384-
print(df.to_string())
1385-
1386-
if len(df) == limit:
1387-
print(f"[limited by {limit} objects]")
1388-
13891365
def clone(self, new_table=True) -> "Self":
13901366
obj = copy(self)
13911367
obj.steps = obj.steps.copy()

tests/func/test_dataset_query.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3213,19 +3213,6 @@ def test_to_records(simple_ds_query):
32133213
assert simple_ds_query.to_records() == SIMPLE_DS_QUERY_RECORDS
32143214

32153215

3216-
@pytest.mark.parametrize(
3217-
"cloud_type,version_aware",
3218-
[("s3", True)],
3219-
indirect=True,
3220-
)
3221-
def test_to_pandas(simple_ds_query):
3222-
import pandas as pd
3223-
3224-
df = simple_ds_query.to_pandas()
3225-
expected = pd.DataFrame.from_records(SIMPLE_DS_QUERY_RECORDS)
3226-
assert (df == expected).all(axis=None)
3227-
3228-
32293216
@pytest.mark.parametrize("method", ["to_records", "extract"])
32303217
@pytest.mark.parametrize("save", [True, False])
32313218
@pytest.mark.parametrize(

tests/unit/lib/test_datachain.py

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,8 @@ def test_from_features(catalog):
112112
params="parent",
113113
output={"file": File, "t1": MyFr},
114114
)
115-
df1 = ds.to_pandas()
116-
117-
assert df1[["t1.nnn", "t1.count"]].equals(
118-
pd.DataFrame({"t1.nnn": ["n1", "n2", "n1"], "t1.count": [3, 5, 1]})
119-
)
115+
for i, (_, t1) in enumerate(ds.iterate()):
116+
assert t1 == features[i]
120117

121118

122119
def test_preserve_feature_schema(catalog):
@@ -212,33 +209,33 @@ class _TestFr(BaseModel):
212209
params="t1",
213210
output={"x": _TestFr},
214211
)
212+
# assert ds.collect() == 1
215213

216-
df = ds.to_pandas()
214+
for i, (x,) in enumerate(ds.iterate()):
215+
assert isinstance(x, _TestFr)
217216

218-
assert df["x.my_name"].tolist() == ["n1", "n2", "n1"]
219-
assert np.allclose(df["x.sqrt"], [math.sqrt(x) for x in [3, 5, 1]])
220-
with pytest.raises(KeyError):
221-
df["x.t1.nnn"]
217+
fr = features[i]
218+
test_fr = _TestFr(file=File(name=""), sqrt=math.sqrt(fr.count), my_name=fr.nnn)
219+
assert x == test_fr
222220

223221

224222
def test_map(catalog):
225223
class _TestFr(BaseModel):
226224
sqrt: float
227225
my_name: str
228226

229-
ds = DataChain.from_values(t1=features)
230-
231-
df = ds.map(
227+
dc = DataChain.from_values(t1=features).map(
232228
x=lambda m_fr: _TestFr(
233229
sqrt=math.sqrt(m_fr.count),
234230
my_name=m_fr.nnn + "_suf",
235231
),
236232
params="t1",
237233
output={"x": _TestFr},
238-
).to_pandas()
234+
)
239235

240-
assert df["x.my_name"].tolist() == ["n1_suf", "n2_suf", "n1_suf"]
241-
assert np.allclose(df["x.sqrt"], [math.sqrt(x) for x in [3, 5, 1]])
236+
assert dc.collect_one("x") == [
237+
_TestFr(sqrt=math.sqrt(fr.count), my_name=fr.nnn + "_suf") for fr in features
238+
]
242239

243240

244241
def test_agg(catalog):
@@ -247,26 +244,31 @@ class _TestFr(BaseModel):
247244
cnt: int
248245
my_name: str
249246

250-
df = (
251-
DataChain.from_values(t1=features)
252-
.agg(
253-
x=lambda frs: [
254-
_TestFr(
255-
f=File(name=""),
256-
cnt=sum(f.count for f in frs),
257-
my_name="-".join([fr.nnn for fr in frs]),
258-
)
259-
],
260-
partition_by=C.t1.nnn,
261-
params="t1",
262-
output={"x": _TestFr},
263-
)
264-
.to_pandas()
247+
dc = DataChain.from_values(t1=features).agg(
248+
x=lambda frs: [
249+
_TestFr(
250+
f=File(name=""),
251+
cnt=sum(f.count for f in frs),
252+
my_name="-".join([fr.nnn for fr in frs]),
253+
)
254+
],
255+
partition_by=C.t1.nnn,
256+
params="t1",
257+
output={"x": _TestFr},
265258
)
266259

267-
assert len(df) == 2
268-
assert df["x.my_name"].tolist() == ["n1-n1", "n2"]
269-
assert df["x.cnt"].tolist() == [4, 5]
260+
assert dc.collect_one("x") == [
261+
_TestFr(
262+
f=File(name=""),
263+
cnt=sum(fr.count for fr in features if fr.nnn == "n1"),
264+
my_name="-".join([fr.nnn for fr in features if fr.nnn == "n1"]),
265+
),
266+
_TestFr(
267+
f=File(name=""),
268+
cnt=sum(fr.count for fr in features if fr.nnn == "n2"),
269+
my_name="-".join([fr.nnn for fr in features if fr.nnn == "n2"]),
270+
),
271+
]
270272

271273

272274
def test_agg_two_params(catalog):
@@ -294,10 +296,8 @@ class _TestFr(BaseModel):
294296
output={"x": _TestFr},
295297
)
296298

297-
df = ds.to_pandas()
298-
assert len(df) == 2
299-
assert df["x.my_name"].tolist() == ["n1-n1", "n2"]
300-
assert df["x.cnt"].tolist() == [12, 15]
299+
assert ds.collect_one("x.my_name") == ["n1-n1", "n2"]
300+
assert ds.collect_one("x.cnt") == [12, 15]
301301

302302

303303
def test_agg_simple_iterator(catalog):
@@ -356,10 +356,8 @@ def func(key, val) -> Iterator[tuple[File, _ImageGroup]]:
356356
values = [1, 5, 9]
357357
ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key"))
358358

359-
df = ds.to_pandas()
360-
assert len(df) == 2
361-
assert df["x_1.name"].tolist() == ["n1-n1", "n2"]
362-
assert df["x_1.size"].tolist() == [10, 5]
359+
assert ds.collect_one("x_1.name") == ["n1-n1", "n2"]
360+
assert ds.collect_one("x_1.size") == [10, 5]
363361

364362

365363
def test_agg_tuple_result_generator(catalog):
@@ -376,10 +374,8 @@ def func(key, val) -> Generator[tuple[File, _ImageGroup], None, None]:
376374
values = [1, 5, 9]
377375
ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key"))
378376

379-
df = ds.to_pandas()
380-
assert len(df) == 2
381-
assert df["x_1.name"].tolist() == ["n1-n1", "n2"]
382-
assert df["x_1.size"].tolist() == [10, 5]
377+
assert ds.collect_one("x_1.name") == ["n1-n1", "n2"]
378+
assert ds.collect_one("x_1.size") == [10, 5]
383379

384380

385381
def test_iterate(catalog):
@@ -829,15 +825,15 @@ def test_from_features_object_name(tmp_dir, catalog):
829825
values = ["odd" if num % 2 else "even" for num in fib]
830826

831827
dc = DataChain.from_values(fib=fib, odds=values, object_name="custom")
832-
assert "custom.fib" in dc.to_pandas().columns
828+
assert "custom.fib" in dc.to_pandas(flatten=True).columns
833829

834830

835831
def test_parse_tabular_object_name(tmp_dir, catalog):
836832
df = pd.DataFrame(DF_DATA)
837833
path = tmp_dir / "test.parquet"
838834
df.to_parquet(path)
839-
dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="name")
840-
assert "name.first_name" in dc.to_pandas().columns
835+
dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="tbl")
836+
assert "tbl.first_name" in dc.to_pandas(flatten=True).columns
841837

842838

843839
def test_sys_feature(tmp_dir, catalog):
@@ -868,3 +864,12 @@ def test_sys_feature(tmp_dir, catalog):
868864
MyFr(nnn="n1", count=1),
869865
]
870866
assert "sys" not in ds_no_sys.catalog.get_dataset("ds_no_sys").feature_schema
867+
868+
869+
def test_to_pandas_multi_level():
    """Nested signals should surface as hierarchical (MultiIndex) columns."""
    frame = DataChain.from_values(t1=features).to_pandas()

    assert "t1" in frame.columns
    nested = frame["t1"]
    assert "nnn" in nested.columns
    assert "count" in nested.columns
    assert nested["count"].tolist() == [3, 5, 1]

0 commit comments

Comments
 (0)