Skip to content

Commit c32e527

Browse files
author
The kauldron Authors
committed
Refactor training transforms to reuse the inference transforms.
PiperOrigin-RevId: 869786054
1 parent 66af85d commit c32e527

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

kauldron/contrib/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from kauldron.contrib.data.preprocessing import TemporalRandomWalk
7373
from kauldron.contrib.data.preprocessing import TemporalRandomWindow
7474
from kauldron.contrib.data.preprocessing import TimeChunkedFlattenVideo
75+
from kauldron.contrib.data.preprocessing import TreeUnflattenForKey
7576
from kauldron.contrib.data.preprocessing import VStack
7677
from kauldron.contrib.data.preprocessing import ValueRange
7778
from kauldron.contrib.data.preprocessing import VideoMAENormalization

kauldron/contrib/data/preprocessing.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import abc
2020
from collections.abc import Mapping
2121
import dataclasses
22+
import functools
2223
import typing
2324
from typing import Any, Callable, List, Literal, Optional, Sequence
2425

@@ -1630,3 +1631,63 @@ def map(self, features: dict[str, tf.Tensor]) -> dict[str, tf.Tensor]:
16301631

16311632
features[self.output_key] = final_video
16321633
return features
1634+
1635+
1636+
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
1637+
class TreeUnflattenForKey(kd.data.MapTransform):
1638+
"""Unflattens a previously flattened dictionary within an element.
1639+
1640+
This transform is designed to reverse the effect of
1641+
kd.data.TreeFlattenWithPath.
1642+
"""
1643+
1644+
key: str
1645+
separator: str = "_"
1646+
1647+
@functools.cached_property
1648+
def _prefix(self) -> str:
1649+
return self.key + self.separator
1650+
1651+
def map(self, element: dict[str, Any]) -> dict[str, Any]:
1652+
element = dict(element) # Ensure the element is mutable
1653+
flat_subtree = {}
1654+
keys_to_remove = []
1655+
1656+
# Extract keys belonging to the flattened subtree
1657+
for k, v in element.items():
1658+
if k.startswith(self._prefix):
1659+
flat_subtree[k] = v
1660+
keys_to_remove.append(k)
1661+
1662+
if not flat_subtree:
1663+
# No keys to unflatten for the specified key
1664+
return element
1665+
1666+
# Reconstruct the nested dictionary
1667+
nested_subtree = {}
1668+
for long_key, value in flat_subtree.items():
1669+
# Remove the prefix
1670+
path_str = long_key[len(self._prefix) :]
1671+
parts = path_str.split(self.separator)
1672+
1673+
current_level = nested_subtree
1674+
for i, part in enumerate(parts):
1675+
if i == len(parts) - 1:
1676+
# Last part of the path, assign the value
1677+
current_level[part] = value
1678+
else:
1679+
# Navigate/create nested dictionaries
1680+
if part not in current_level:
1681+
current_level[part] = {}
1682+
elif not isinstance(current_level[part], dict):
1683+
# This case should ideally not happen if flattening was consistent
1684+
# Handle potential conflicts if a path is both a leaf and a branch
1685+
raise ValueError(f"Conflict at path {'.'.join(parts[:i+1])}")
1686+
current_level = current_level[part]
1687+
1688+
# Update the element: add the nested structure, remove old flat keys
1689+
element[self.key] = nested_subtree
1690+
for k in keys_to_remove:
1691+
del element[k]
1692+
1693+
return element

0 commit comments

Comments
 (0)