Skip to content

Commit 1e5178b

Browse files
authored
Fix renaming object or normal signal with .mutate() (#217)
* fix renaming object or normal signal with mutate * removed print and added more tests * added test for has object * refactoring tests * added one assert in test * simplifying code, removing has_object * removing sys from tests * fixing test
1 parent 7832e10 commit 1e5178b

File tree

4 files changed

+132
-10
lines changed

4 files changed

+132
-10
lines changed

src/datachain/lib/dc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,19 @@ def mutate(self, **kwargs) -> "Self":
829829
)
830830
```
831831
"""
832-
chain = super().mutate(**kwargs)
833-
chain.signals_schema = self.signals_schema.mutate(kwargs)
832+
mutated = {}
833+
schema = self.signals_schema
834+
for name, value in kwargs.items():
835+
if isinstance(value, Column):
836+
# renaming existing column
837+
for signal in schema.db_signals(name=value.name, as_columns=True):
838+
mutated[signal.name.replace(value.name, name, 1)] = signal
839+
else:
840+
# adding new signal
841+
mutated[name] = value
842+
843+
chain = super().mutate(**mutated)
844+
chain.signals_schema = schema.mutate(kwargs)
834845
return chain
835846

836847
@property
@@ -1099,7 +1110,7 @@ def subtract( # type: ignore[override]
10991110
)
11001111
else:
11011112
signals = self.signals_schema.resolve(*on).db_signals()
1102-
return super()._subtract(other, signals)
1113+
return super()._subtract(other, signals) # type: ignore[arg-type]
11031114

11041115
@classmethod
11051116
def from_values(

src/datachain/lib/signal_schema.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from datachain.lib.file import File
2626
from datachain.lib.model_store import ModelStore
2727
from datachain.lib.utils import DataChainParamsError
28-
from datachain.query.schema import DEFAULT_DELIMITER
28+
from datachain.query.schema import DEFAULT_DELIMITER, Column
2929

3030
if TYPE_CHECKING:
3131
from datachain.catalog import Catalog
@@ -222,13 +222,30 @@ def row_to_features(
222222
res.append(obj)
223223
return res
224224

225-
def db_signals(self) -> list[str]:
226-
return [
225+
def db_signals(
226+
self, name: Optional[str] = None, as_columns=False
227+
) -> Union[list[str], list[Column]]:
228+
"""
229+
Returns DB columns as strings or Column objects with proper types
230+
Optionally, it can filter results by a specific object, returning only its signals
231+
"""
232+
signals = [
227233
DEFAULT_DELIMITER.join(path)
228-
for path, _, has_subtree, _ in self.get_flat_tree()
234+
if not as_columns
235+
else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type))
236+
for path, _type, has_subtree, _ in self.get_flat_tree()
229237
if not has_subtree
230238
]
231239

240+
if name:
241+
signals = [
242+
s
243+
for s in signals
244+
if str(s) == name or str(s).startswith(f"{name}{DEFAULT_DELIMITER}")
245+
]
246+
247+
return signals # type: ignore[return-value]
248+
232249
def resolve(self, *names: str) -> "SignalSchema":
233250
schema = {}
234251
for field in names:
@@ -282,7 +299,18 @@ def clone_without_file_signals(self) -> "SignalSchema":
282299
return SignalSchema(schema)
283300

284301
def mutate(self, args_map: dict) -> "SignalSchema":
285-
return SignalSchema(self.values | sql_to_python(args_map))
302+
new_values = self.values.copy()
303+
304+
for name, value in args_map.items():
305+
if isinstance(value, Column) and value.name in self.values:
306+
# renaming existing signal
307+
del new_values[value.name]
308+
new_values[name] = self.values[value.name]
309+
else:
310+
# adding new signal
311+
new_values.update(sql_to_python({name: value}))
312+
313+
return SignalSchema(new_values)
286314

287315
def clone_without_sys_signals(self) -> "SignalSchema":
288316
schema = copy.deepcopy(self.values)

tests/unit/lib/test_datachain.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,3 +1304,59 @@ def func(key, val) -> Iterator[tuple[File, str]]:
13041304
assert ds.limit(3).gen(res=func).limit(2).count() == 2
13051305
assert ds.limit(2).gen(res=func).limit(3).count() == 3
13061306
assert ds.limit(3).gen(res=func).limit(10).count() == 9
1307+
1308+
1309+
def test_rename_non_object_column_name_with_mutate(catalog):
1310+
ds = DataChain.from_values(ids=[1, 2, 3])
1311+
ds = ds.mutate(my_ids=Column("ids"))
1312+
1313+
assert ds.signals_schema.values == {"my_ids": int}
1314+
assert list(ds.order_by("my_ids").collect("my_ids")) == [1, 2, 3]
1315+
1316+
ds.save("mutated")
1317+
1318+
ds = DataChain(name="mutated")
1319+
assert ds.signals_schema.values.get("my_ids") is int
1320+
assert "ids" not in ds.signals_schema.values
1321+
assert list(ds.order_by("my_ids").collect("my_ids")) == [1, 2, 3]
1322+
1323+
1324+
def test_rename_object_column_name_with_mutate(catalog):
1325+
names = ["a", "b", "c"]
1326+
sizes = [1, 2, 3]
1327+
files = [File(name=name, size=size) for name, size in zip(names, sizes)]
1328+
1329+
ds = DataChain.from_values(file=files, ids=[1, 2, 3])
1330+
ds = ds.mutate(fname=Column("file.name"))
1331+
1332+
assert list(ds.order_by("fname").collect("fname")) == ["a", "b", "c"]
1333+
assert ds.signals_schema.values == {"file": File, "ids": int, "fname": str}
1334+
1335+
# check that the changes persist after saving
1336+
ds.save("mutated")
1337+
1338+
ds = DataChain(name="mutated")
1339+
assert ds.signals_schema.values.get("file") is File
1340+
assert ds.signals_schema.values.get("ids") is int
1341+
assert ds.signals_schema.values.get("fname") is str
1342+
assert list(ds.order_by("fname").collect("fname")) == ["a", "b", "c"]
1343+
1344+
1345+
def test_rename_object_name_with_mutate(catalog):
1346+
names = ["a", "b", "c"]
1347+
sizes = [1, 2, 3]
1348+
files = [File(name=name, size=size) for name, size in zip(names, sizes)]
1349+
1350+
ds = DataChain.from_values(file=files, ids=[1, 2, 3])
1351+
ds = ds.mutate(my_file=Column("file"))
1352+
1353+
assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]
1354+
assert ds.signals_schema.values == {"my_file": File, "ids": int}
1355+
1356+
ds.save("mutated")
1357+
1358+
ds = DataChain(name="mutated")
1359+
assert ds.signals_schema.values.get("my_file") is File
1360+
assert ds.signals_schema.values.get("ids") is int
1361+
assert "file" not in ds.signals_schema.values
1362+
assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]

tests/unit/lib/test_signal_schema.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from datachain import DataModel
6+
from datachain import Column, DataModel
77
from datachain.lib.convert.flatten import flatten
88
from datachain.lib.file import File
99
from datachain.lib.signal_schema import (
@@ -251,7 +251,7 @@ def test_print_types():
251251
assert SignalSchema._type_to_str(t) == v
252252

253253

254-
def test_bd_signals():
254+
def test_db_signals():
255255
spec = {"name": str, "age": float, "fr": MyType2}
256256
lst = list(SignalSchema(spec).db_signals())
257257

@@ -264,6 +264,33 @@ def test_bd_signals():
264264
]
265265

266266

267+
def test_db_signals_filtering_by_name():
268+
schema = SignalSchema({"name": str, "age": float, "fr": MyType2})
269+
270+
assert list(schema.db_signals(name="fr")) == [
271+
"fr__name",
272+
"fr__deep__aa",
273+
"fr__deep__bb",
274+
]
275+
assert list(schema.db_signals(name="name")) == ["name"]
276+
assert list(schema.db_signals(name="missing")) == []
277+
278+
279+
def test_db_signals_as_columns():
280+
spec = {"name": str, "age": float, "fr": MyType2}
281+
lst = list(SignalSchema(spec).db_signals(as_columns=True))
282+
283+
assert all(isinstance(s, Column) for s in lst)
284+
285+
assert [(c.name, type(c.type)) for c in lst] == [
286+
("name", String),
287+
("age", Float),
288+
("fr__name", String),
289+
("fr__deep__aa", Int64),
290+
("fr__deep__bb", String),
291+
]
292+
293+
267294
def test_row_to_objs():
268295
spec = {"name": str, "age": float, "fr": MyType2}
269296
schema = SignalSchema(spec)

0 commit comments

Comments
 (0)