|
19 | 19 | import abc |
20 | 20 | from collections.abc import Mapping |
21 | 21 | import dataclasses |
| 22 | +import functools |
22 | 23 | import typing |
23 | 24 | from typing import Any, Callable, List, Literal, Optional, Sequence |
24 | 25 |
|
@@ -1630,3 +1631,63 @@ def map(self, features: dict[str, tf.Tensor]) -> dict[str, tf.Tensor]: |
1630 | 1631 |
|
1631 | 1632 | features[self.output_key] = final_video |
1632 | 1633 | return features |
| 1634 | + |
| 1635 | + |
| 1636 | +@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) |
| 1637 | +class TreeUnflattenForKey(kd.data.MapTransform): |
| 1638 | + """Unflattens a previously flattened dictionary within an element. |
| 1639 | +
|
| 1640 | + This transform is designed to reverse the effect of |
| 1641 | + kd.data.TreeFlattenWithPath. |
| 1642 | + """ |
| 1643 | + |
| 1644 | + key: str |
| 1645 | + separator: str = "_" |
| 1646 | + |
| 1647 | + @functools.cached_property |
| 1648 | + def _prefix(self) -> str: |
| 1649 | + return self.key + self.separator |
| 1650 | + |
| 1651 | + def map(self, element: dict[str, Any]) -> dict[str, Any]: |
| 1652 | + element = dict(element) # Ensure the element is mutable |
| 1653 | + flat_subtree = {} |
| 1654 | + keys_to_remove = [] |
| 1655 | + |
| 1656 | + # Extract keys belonging to the flattened subtree |
| 1657 | + for k, v in element.items(): |
| 1658 | + if k.startswith(self._prefix): |
| 1659 | + flat_subtree[k] = v |
| 1660 | + keys_to_remove.append(k) |
| 1661 | + |
| 1662 | + if not flat_subtree: |
| 1663 | + # No keys to unflatten for the specified key |
| 1664 | + return element |
| 1665 | + |
| 1666 | + # Reconstruct the nested dictionary |
| 1667 | + nested_subtree = {} |
| 1668 | + for long_key, value in flat_subtree.items(): |
| 1669 | + # Remove the prefix |
| 1670 | + path_str = long_key[len(self._prefix) :] |
| 1671 | + parts = path_str.split(self.separator) |
| 1672 | + |
| 1673 | + current_level = nested_subtree |
| 1674 | + for i, part in enumerate(parts): |
| 1675 | + if i == len(parts) - 1: |
| 1676 | + # Last part of the path, assign the value |
| 1677 | + current_level[part] = value |
| 1678 | + else: |
| 1679 | + # Navigate/create nested dictionaries |
| 1680 | + if part not in current_level: |
| 1681 | + current_level[part] = {} |
| 1682 | + elif not isinstance(current_level[part], dict): |
| 1683 | + # This case should ideally not happen if flattening was consistent |
| 1684 | + # Handle potential conflicts if a path is both a leaf and a branch |
| 1685 | + raise ValueError(f"Conflict at path {'.'.join(parts[:i+1])}") |
| 1686 | + current_level = current_level[part] |
| 1687 | + |
| 1688 | + # Update the element: add the nested structure, remove old flat keys |
| 1689 | + element[self.key] = nested_subtree |
| 1690 | + for k in keys_to_remove: |
| 1691 | + del element[k] |
| 1692 | + |
| 1693 | + return element |
0 commit comments