diff --git a/kauldron/data/transforms/base.py b/kauldron/data/transforms/base.py index 092438d6..302105e8 100644 --- a/kauldron/data/transforms/base.py +++ b/kauldron/data/transforms/base.py @@ -35,6 +35,7 @@ class Elements(tr_abc.MapTransform): drop: Iterable[str] = () rename: Mapping[str, str] = _FrozenDict() copy: Mapping[str, str] = _FrozenDict() + skip_missing: bool = False def __post_init__(self): if self.keep and self.drop: @@ -83,7 +84,7 @@ def map(self, features): if bool(self.copy): copy_keys = set(self.copy.keys()) missing_copy_keys = copy_keys - feature_keys - if missing_copy_keys: + if missing_copy_keys and not self.skip_missing: raise KeyError( f"copy-key(s) {missing_copy_keys} not found in batch. " f"Available keys are {sorted(feature_keys)!r}." @@ -95,13 +96,16 @@ def map(self, features): f"copy-value(s) {overlap_keys} will overwrite existing values in " f"batch. Existing keys are {sorted(feature_keys)!r}." ) - copy_output = {v: features[k] for k, v in self.copy.items()} + copy_output = {} + for k, v in self.copy.items(): + if k in features: + copy_output[v] = features[k] # resolve keep or drop if self.keep: keep_keys = set(self.keep) missing_keep_keys = keep_keys - feature_keys - if missing_keep_keys: + if missing_keep_keys and not self.skip_missing: raise KeyError( f"keep-key(s) {missing_keep_keys} not found in batch. " f"Available keys are {sorted(feature_keys)!r}." @@ -110,7 +114,7 @@ def map(self, features): elif self.drop: drop_keys = set(self.drop) missing_drop_keys = drop_keys - feature_keys - if missing_drop_keys: + if missing_drop_keys and not self.skip_missing: raise KeyError( f"drop-key(s) {missing_drop_keys} not found in batch. " f"Available keys are {sorted(feature_keys)!r}." @@ -127,7 +131,7 @@ def map(self, features): # resolve renaming rename_keys = set(self.rename.keys()) missing_rename_keys = rename_keys - feature_keys - if missing_rename_keys: + if missing_rename_keys and not self.skip_missing: raise KeyError( f"rename-key(s) {missing_rename_keys} not found in batch. " f"Available keys are {sorted(feature_keys)!r}." diff --git a/kauldron/data/transforms/base_test.py b/kauldron/data/transforms/base_test.py index 219808a7..71dfdb73 100644 --- a/kauldron/data/transforms/base_test.py +++ b/kauldron/data/transforms/base_test.py @@ -29,6 +29,22 @@ def test_elements_keep(): assert after["no_copy"] == before["no"] +def test_elements_keep_skip_missing(): + el = kd.data.py.Elements( + keep={"yes", "definitely", "missing"}, + rename={"old": "new", "old_missing": "new_missing"}, + copy={"no": "no_copy", "missing": "missing_copy"}, + skip_missing=True, + ) + before = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5} + after = el.map(before) + assert set(after.keys()) == {"yes", "definitely", "new", "no_copy"} + assert after["yes"] == before["yes"] + assert after["definitely"] == before["definitely"] + assert after["new"] == before["old"] + assert after["no_copy"] == before["no"] + + def test_elements_drop(): el = kd.data.py.Elements( drop={"no", "drop"}, rename={"old": "new"}, copy={"yes": "yes_copy"} @@ -42,6 +58,22 @@ def test_elements_drop(): assert after["yes_copy"] == before["yes"] +def test_elements_drop_skip_missing(): + el = kd.data.py.Elements( + drop={"no", "drop", "missing"}, + rename={"old": "new", "old_missing": "new_missing"}, + copy={"yes": "yes_copy", "missing": "missing_copy"}, + skip_missing=True, + ) + before = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5} + after = el.map(before) + assert set(after.keys()) == {"yes", "definitely", "new", "yes_copy"} + assert after["yes"] == before["yes"] + assert after["definitely"] == before["definitely"] + assert after["new"] == before["old"] + assert after["yes_copy"] == before["yes"] + + def test_elements_rename_only(): el = kd.data.py.Elements(rename={"old": "new"}) before = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5} @@ -60,6 +92,12 @@ def test_elements_rename_overwrite_raises(): with pytest.raises(KeyError): el.map(before) + # Same as above but with skip_missing=True. + el = kd.data.py.Elements(rename={"old": "oops"}, skip_missing=True) + before = {"old": 1, "oops": 2} + with pytest.raises(KeyError): + el.map(before) + def test_elements_copy_only(): el = kd.data.py.Elements(copy={"yes": "no", "old": "new"}) @@ -84,3 +122,26 @@ def test_elements_copy_overwrite_raises(): # copy two fields to the same target name with pytest.raises(ValueError): _ = kd.data.py.Elements(copy={"old": "oops", "yes": "oops"}) + + +def test_elements_copy_overwrite_raises_skip_missing(): + # copy to an existing key + el = kd.data.py.Elements( + copy={"old": "oops", "missing": "missing_copy"}, skip_missing=True + ) + before = {"old": 1, "oops": 2} + with pytest.raises(KeyError): + el.map(before) + # copy to a key that is also a rename target + with pytest.raises(KeyError): + _ = kd.data.py.Elements( + copy={"old": "oops"}, + rename={"yes": "oops", "missing": "missing_remamed"}, + skip_missing=True, + ) + # copy two fields to the same target name + with pytest.raises(ValueError): + _ = kd.data.py.Elements( + copy={"old": "oops", "yes": "oops", "missing": "missing_copy"}, + skip_missing=True, + )