Skip to content

Commit a8bc786

Browse files
authored
fix select then distinct chain (#213)
1 parent 6f276fc commit a8bc786

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/datachain/lib/signal_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,11 @@ def _find_in_tree(self, path: list[str]) -> DataType:
243243
curr_type = None
244244
i = 0
245245
while curr_tree is not None and i < len(path):
246-
if val := curr_tree.get(path[i], None):
246+
if val := curr_tree.get(path[i]):
247247
curr_type, curr_tree = val
248+
elif i == 0 and len(path) > 1 and (val := curr_tree.get(".".join(path))):
249+
curr_type, curr_tree = val
250+
break
248251
else:
249252
curr_type = None
250253
i += 1

tests/unit/lib/test_datachain.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,41 @@ def test_select_restore_from_saving(catalog):
612612
assert n == len(features_nested)
613613

614614

615+
def test_select_distinct(catalog):
    """Chain ``select`` -> ``distinct`` -> ``order_by`` on a nested signal.

    Five ``Embedding`` objects are stored under the ``embedding`` signal; two
    of them share the same ``values`` vector.  After selecting the nested
    columns, de-duplicating on ``embedding.values`` and ordering by it, the
    collected rows must contain exactly the four unique vectors, in the same
    order as ``expected``.
    """

    class Embedding(BaseModel):
        id: int
        filename: str
        values: list[float]

    # The four unique vectors, listed in the order ``order_by`` should yield.
    expected = [
        [0.1, 0.3],
        [0.1, 0.4],
        [0.1, 0.5],
        [0.1, 0.6],
    ]

    # Input rows are intentionally out of order, and ``expected[1]`` is used
    # twice so that ``distinct`` has a duplicate to drop.
    samples = [
        (1, "a.jpg", expected[0]),
        (2, "b.jpg", expected[2]),
        (3, "c.jpg", expected[1]),
        (4, "d.jpg", expected[1]),
        (5, "e.jpg", expected[3]),
    ]
    chain = DataChain.from_values(
        embedding=[
            Embedding(id=ident, filename=name, values=vector)
            for ident, name, vector in samples
        ],
    )

    rows = (
        chain.select("embedding.values", "embedding.filename")
        .distinct("embedding.values")
        .order_by("embedding.values")
        .collect()
    )

    # Keep only the first element of every collected row (presumably the
    # ``embedding.values`` column, as it is the first one selected).
    vectors = [row[0] for row in rows]
    assert len(vectors) == 4

    # Compare component by component with a float tolerance.
    for component in (0, 1):
        got = [vector[component] for vector in vectors]
        want = [vector[component] for vector in expected]
        assert np.allclose(got, want)
648+
649+
615650
def test_from_dataset_name_version(catalog):
616651
name = "test-version"
617652
DataChain.from_values(

0 commit comments

Comments
 (0)