Skip to content

Commit fa9c4ba

Browse files
Add torchtree_impl to support torch.compile with torch 2.8.0. (#21661)
* Add `torchtree_impl` to support torch.compile with torch 2.8.0. * Update. * Update `torchtree_impl`. * Update jax==0.6.2, torch==2.8.0 and tensorflow==2.20.0. * Revert the backend version. * Fix tests. * Prevent the torch compiler from breaking.
1 parent d19fece commit fa9c4ba

File tree

9 files changed

+370
-34
lines changed

9 files changed

+370
-34
lines changed

keras/src/trainers/compile_utils.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -690,17 +690,34 @@ def __call__(self, y_true, y_pred, sample_weight=None):
690690
return self.call(y_true, y_pred, sample_weight)
691691

692692
def call(self, y_true, y_pred, sample_weight=None):
693+
def resolve_path(path, object):
    """Follow `path` into `object` by successive indexing.

    Args:
        path: Iterable of keys/indices, applied from first to last.
        object: The container to index into.

    Returns:
        The element reached after applying every key in `path`; `object`
        itself when `path` is empty.
    """
    node = object
    for key in path:
        node = node[key]
    return node
697+
693698
if not tree.is_nested(y_true) and not tree.is_nested(y_pred):
694699
# Fast path: single output case / no loss-tracking metric.
695700
if not self.built:
696701
self.build(y_true, y_pred)
697-
_, loss_fn, loss_weight, _ = self._flat_losses[0]
698-
loss_value = ops.cast(
699-
loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
700-
)
701-
if loss_weight is not None:
702-
loss_value = ops.multiply(loss_value, loss_weight)
703-
return loss_value
702+
# Although we are in the fast path, we still need to iterate
703+
# through the losses to prevent the torch compiler from failing.
704+
loss_values = []
705+
for path, loss_fn, loss_weight, _ in self._flat_losses:
706+
y_t, y_p = (
707+
resolve_path(path, y_true),
708+
resolve_path(path, y_pred),
709+
)
710+
if sample_weight is not None and tree.is_nested(sample_weight):
711+
_sample_weight = resolve_path(path, sample_weight)
712+
else:
713+
_sample_weight = sample_weight
714+
value = ops.cast(
715+
loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
716+
)
717+
if loss_weight is not None:
718+
value = ops.multiply(value, loss_weight)
719+
loss_values.append(value)
720+
return loss_values[0]
704721

705722
try:
706723
tree.assert_same_structure(y_pred, y_true)
@@ -779,11 +796,6 @@ def call(self, y_true, y_pred, sample_weight=None):
779796
# Iterate all losses in flat form.
780797
loss_values = []
781798

782-
def resolve_path(path, object):
783-
for _path in path:
784-
object = object[_path]
785-
return object
786-
787799
for (path, loss_fn, loss_weight, _), metric in zip(
788800
self._flat_losses, metrics
789801
):

keras/src/trainers/trainer_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,6 +1869,11 @@ def test_training_arg(self):
18691869
)
18701870
@pytest.mark.requires_trainable_backend
18711871
def test_on_batch_methods(self, run_eagerly, jit_compile):
1872+
if backend.backend() == "torch" and jit_compile:
1873+
self.skipTest(
1874+
"test_on_batch with jit_compile=True not supported in torch "
1875+
"backend yet."
1876+
)
18721877
model = ExampleModel(units=3)
18731878
x = np.ones((100, 4))
18741879
y = np.zeros((100, 3))
@@ -1925,6 +1930,11 @@ def test_on_batch_methods(self, run_eagerly, jit_compile):
19251930
]
19261931
)
19271932
def test_on_batch_methods_without_training(self, run_eagerly, jit_compile):
1933+
if backend.backend() == "torch" and jit_compile:
1934+
self.skipTest(
1935+
"test_on_batch with jit_compile=True not supported in torch "
1936+
"backend yet."
1937+
)
19281938
model = ExampleModel(units=3)
19291939
x = np.ones((100, 4))
19301940
y = np.zeros((100, 3))

keras/src/tree/torchtree_impl.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from collections import defaultdict
2+
3+
from torch.utils import _pytree as torch_tree
4+
5+
6+
def register_tree_node_class(cls):
7+
torch_tree.register_pytree_node(
8+
cls,
9+
flatten_fn=lambda x: x.torchtree_flatten(),
10+
unflatten_fn=cls.torchtree_unflatten,
11+
serialized_type_name=f"{cls.__name__}",
12+
flatten_with_keys_fn=lambda x: x.torchtree_flatten_with_keys(),
13+
)
14+
return cls
15+
16+
17+
def _tree_is_leaf(tree, is_leaf=None):
    """Tell whether `tree` is a leaf from torch pytree's point of view.

    Args:
        tree: The object to inspect.
        is_leaf: Optional predicate. When it returns a truthy value for
            `tree`, `tree` is reported as a leaf regardless of its type.

    Returns:
        `True` if `is_leaf` accepts `tree` or if the node type of `tree`
        is not among torch pytree's supported node types.
    """
    forced_leaf = is_leaf is not None and is_leaf(tree)
    if forced_leaf:
        return True
    node_type = torch_tree._get_node_type(tree)
    return node_type not in torch_tree.SUPPORTED_NODES
21+
22+
23+
def _dict_to_ordered_dict(structure):
24+
# We need to sort dict and defaultdict to ensure a deterministic order that
25+
# that is consistent with other tree implementations.
26+
def func(x):
27+
if type(x) is dict:
28+
return {k: x[k] for k in sorted(x.keys())}
29+
elif type(x) is defaultdict:
30+
return defaultdict(
31+
x.default_factory,
32+
{k: x[k] for k in sorted(x.keys())},
33+
)
34+
return None
35+
36+
def traverse_children():
37+
children, treedef = torch_tree.tree_flatten(
38+
structure,
39+
is_leaf=lambda x: x is not structure,
40+
)
41+
if treedef.num_nodes == 1 and treedef.num_leaves == 1:
42+
return structure
43+
else:
44+
return torch_tree.tree_unflatten(
45+
[_dict_to_ordered_dict(c) for c in children],
46+
treedef,
47+
)
48+
49+
ret = func(structure)
50+
if ret is None:
51+
return traverse_children()
52+
if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
53+
return None
54+
return ret
55+
56+
57+
def is_nested(structure):
    """Return `True` when `structure` is a container node, not a leaf.

    Args:
        structure: The object to inspect.

    Returns:
        The negation of `_tree_is_leaf(structure)`.
    """
    is_leaf = _tree_is_leaf(structure)
    return not is_leaf
59+
60+
61+
def traverse(func, structure, top_down=True):
    """Visit every node of `structure`, letting `func` replace nodes.

    Args:
        func: Callable applied to each visited node. Returning `None`
            keeps the node (and, when `top_down`, continues into its
            children). Returning a class named `MAP_TO_NONE` replaces the
            node with `None`. Any other return value replaces the node.
        structure: Arbitrarily nested structure.
        top_down: If `True`, `func` sees a node before its children and a
            non-`None` result stops the descent. If `False`, children are
            traversed first and `func` sees the rebuilt node.

    Returns:
        The traversed structure with the replacements applied.
    """
    structure = _dict_to_ordered_dict(structure)

    def rebuild_from_children():
        # Flatten exactly one level: everything but the root is a leaf.
        parts, spec = torch_tree.tree_flatten(
            structure,
            is_leaf=lambda node: node is not structure,
        )
        if spec.num_nodes == 1 and spec.num_leaves == 1:
            return structure
        return torch_tree.tree_unflatten(
            [traverse(func, part, top_down=top_down) for part in parts],
            spec,
        )

    if top_down:
        result = func(structure)
        if result is None:
            return rebuild_from_children()
    else:
        rebuilt = rebuild_from_children()
        result = func(rebuilt)
        if result is None:
            return rebuilt

    # Detect MAP_TO_NONE without tree_api import to avoid circular import.
    is_map_to_none = isinstance(result, type) and (
        result.__name__ == "MAP_TO_NONE"
    )
    return None if is_map_to_none else result
89+
90+
91+
def flatten(structure):
    """Return the leaves of `structure` as a flat list.

    Dicts are key-sorted first so the leaf order is deterministic and
    consistent with other tree implementations.

    Args:
        structure: Arbitrarily nested structure.

    Returns:
        The list of leaves produced by `torch_tree.tree_flatten`.
    """
    ordered = _dict_to_ordered_dict(structure)
    return torch_tree.tree_flatten(ordered)[0]
97+
98+
99+
def flatten_with_path(structure):
    """Return `(path, leaf)` pairs for every leaf of `structure`.

    Dicts are key-sorted first so the order is deterministic and
    consistent with other tree implementations.

    Each path is a tuple built from torch key-path entries: a
    `SequenceKey` contributes its index, a `MappingKey` its key, and a
    `GetAttrKey` the position of its attribute name among all attribute
    names found in the tree, sorted. Other entry types add nothing.

    Args:
        structure: Arbitrarily nested structure.

    Returns:
        A list of `(path_tuple, leaf)` tuples.
    """
    ordered = _dict_to_ordered_dict(structure)
    keyed_leaves = torch_tree.tree_flatten_with_path(ordered)[0]

    # Collect each attribute name once; its rank in sorted order is the
    # path component used for `GetAttrKey` entries.
    attr_names = set()
    for key_path, _ in keyed_leaves:
        for entry in key_path:
            if isinstance(entry, torch_tree.GetAttrKey):
                attr_names.add(entry.name)
    attr_rank = {name: pos for pos, name in enumerate(sorted(attr_names))}

    results = []
    for key_path, leaf in keyed_leaves:
        components = []
        for entry in key_path:
            if isinstance(entry, torch_tree.GetAttrKey):
                components.append(attr_rank[entry.name])
            elif isinstance(entry, torch_tree.SequenceKey):
                components.append(entry.idx)
            elif isinstance(entry, torch_tree.MappingKey):
                components.append(entry.key)
        results.append((tuple(components), leaf))
    return results
124+
125+
126+
def map_structure(func, *structures, none_is_leaf=True):
127+
if not structures:
128+
raise ValueError("Must provide at least one structure")
129+
130+
map_func = func
131+
if not none_is_leaf:
132+
133+
def func_skipping_none(*args):
134+
# Check if the reference entry (first one) is None
135+
if args[0] is None:
136+
if not all(s is None for s in args):
137+
raise ValueError(
138+
"Structure mismatch: some arguments are None, others "
139+
f"are not. Received arguments: {args}."
140+
)
141+
return None
142+
return func(*args)
143+
144+
map_func = func_skipping_none
145+
146+
return torch_tree.tree_map(map_func, *structures)
147+
148+
149+
def map_structure_up_to(shallow_structure, func, *structures):
    """Apply `func` to `structures`, descending only as far as a shallow one.

    Args:
        shallow_structure: Structure whose leaves decide where `func`
            is applied. Its leaves are never passed to `func`.
        func: Callable receiving the matching entry of each structure.
        *structures: One or more structures to map over.

    Returns:
        The result of `torch_tree.tree_map` driven by `shallow_structure`.

    Raises:
        ValueError: If no structure is given, or if an entry taken from
            `shallow_structure` is not a leaf.
    """
    if not structures:
        raise ValueError("Must provide at least one structure")

    def apply_to_rest(shallow_leaf, *rest):
        # `shallow_structure` must be the shallowest, and `func` only
        # ever sees the entries from `structures`.
        if _tree_is_leaf(shallow_leaf):
            return func(*rest)
        raise ValueError("Structures don't have the same nested structure.")

    return torch_tree.tree_map(apply_to_rest, shallow_structure, *structures)
165+
166+
167+
def assert_same_structure(a, b):
    """Check that `a` and `b` share the same nested structure.

    The check maps over both structures together with
    `torch_tree.tree_map`; the mapped values are discarded.

    Args:
        a: First structure.
        b: Second structure.

    Raises:
        ValueError: If a pair of entries visited together contains a
            non-leaf.
    """

    def ensure_both_leaves(a_leaf, b_leaf):
        if _tree_is_leaf(a_leaf) and _tree_is_leaf(b_leaf):
            return None
        raise ValueError("Structures don't have the same nested structure.")

    torch_tree.tree_map(ensure_both_leaves, a, b)
174+
175+
176+
def assert_same_paths(a, b):
    """Check that `a` and `b` expose exactly the same set of leaf paths.

    Args:
        a: First structure.
        b: Second structure.

    Raises:
        ValueError: If the path sets differ. The message lists the paths
            found in only one of the two structures.
    """
    a_paths = {path for path, _ in flatten_with_path(a)}
    b_paths = {path for path, _ in flatten_with_path(b)}
    if a_paths == b_paths:
        return

    msg = "`a` and `b` don't have the same paths."
    a_diff = a_paths - b_paths
    if a_diff:
        msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
    b_diff = b_paths - a_paths
    if b_diff:
        msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
    raise ValueError(msg)
189+
190+
191+
def pack_sequence_as(structure, flat_sequence):
    """Arrange `flat_sequence` into the nesting of `structure`.

    Dicts in `structure` are key-sorted before its tree spec is computed,
    so leaves are assigned in an order that is deterministic and
    consistent with other tree implementations.

    Args:
        structure: Nested structure providing the target layout.
        flat_sequence: Flat sequence of leaves to pack.

    Returns:
        The result of `torch_tree.tree_unflatten` on `flat_sequence`.
    """
    ordered = _dict_to_ordered_dict(structure)
    treespec = torch_tree.tree_flatten(ordered)[1]
    return torch_tree.tree_unflatten(flat_sequence, treespec)
197+
198+
199+
def lists_to_tuples(structure):
    """Return `structure` with every `list` node replaced by a `tuple`.

    Args:
        structure: Arbitrarily nested structure.

    Returns:
        The result of a bottom-up `traverse` that converts lists.
    """

    def list_to_tuple(instance):
        if isinstance(instance, list):
            return tuple(instance)
        # `None` tells `traverse` to keep the node unchanged.
        return None

    return traverse(list_to_tuple, structure, top_down=False)
204+
205+
206+
def map_shape_structure(func, structure):
    """Apply `func` to `structure`, treating shape-like sequences as leaves.

    A list or tuple whose elements are all `int` or `None` is passed to
    `func` whole instead of being traversed. Dicts are key-sorted first
    so the order is deterministic and consistent with other tree
    implementations.

    Args:
        func: Callable applied to each leaf.
        structure: Arbitrarily nested structure.

    Returns:
        The result of `torch_tree.tree_map` over the sorted structure.
    """

    def looks_like_shape(candidate):
        if not isinstance(candidate, (list, tuple)):
            return False
        return all(isinstance(dim, (int, type(None))) for dim in candidate)

    ordered = _dict_to_ordered_dict(structure)
    return torch_tree.tree_map(func, ordered, is_leaf=looks_like_shape)

keras/src/tree/tree_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import warnings
22

33
from keras.src.api_export import keras_export
4+
from keras.src.backend.config import backend
45
from keras.src.utils.module_utils import dmtree
56
from keras.src.utils.module_utils import optree
67

7-
if optree.available:
8+
if backend() == "torch":
9+
# torchtree_impl is especially used for Torch backend, as it works better
10+
# with torch.compile.
11+
from keras.src.tree import torchtree_impl as tree_impl
12+
elif optree.available:
813
from keras.src.tree import optree_impl as tree_impl
914
elif dmtree.available:
1015
from keras.src.tree import dmtree_impl as tree_impl

0 commit comments

Comments
 (0)