Skip to content

Commit 41a2794

Browse files
committed
feat: add support for preserving duplicate inputs in Keras model export; implement helper functions for tensor history
1 parent 3ad0e26 commit 41a2794

File tree

2 files changed

+87
-13
lines changed

2 files changed

+87
-13
lines changed

cerebros/keras_export.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,43 @@ def _tensor_name(t) -> str:
4444
return str(t)
4545

4646

47+
def _history_layer_of_tensor(t):
48+
"""Return producing layer for a Keras/TensorFlow tensor if available."""
49+
for attr in ("_keras_history", "keras_history"):
50+
if hasattr(t, attr):
51+
hist = getattr(t, attr)
52+
# Keras 3 may expose .layer; Keras 2 exposes tuple (layer, node_index, tensor_index)
53+
layer = getattr(hist, "layer", None)
54+
if layer is not None:
55+
return layer
56+
try:
57+
return hist[0]
58+
except Exception:
59+
return None
60+
return None
61+
62+
63+
def _upstream_layer_name_from_tensor(t) -> str:
64+
layer = _history_layer_of_tensor(t)
65+
return getattr(layer, "name", _tensor_name(t))
66+
67+
68+
def _flatten_inputs(obj):
69+
if obj is None:
70+
return []
71+
if isinstance(obj, (list, tuple)):
72+
out = []
73+
for x in obj:
74+
out.extend(_flatten_inputs(x))
75+
return out
76+
return [obj]
77+
78+
4779
def export_keras_to_graph(model) -> Dict[str, Any]:
4880
"""Export a tf.keras.Model to a minimal DAG spec.
4981
5082
Requires a Functional or Model graph (not a pure Sequential with no names).
83+
Preserves multiplicity and order of inbound connections.
5184
"""
5285
if tf is None:
5386
raise RuntimeError("TensorFlow not available; cannot export Keras model.")
@@ -57,23 +90,34 @@ def export_keras_to_graph(model) -> Dict[str, Any]:
5790
# Build list of nodes with inbound connections (layer order is already topological)
5891
for layer in model.layers:
5992
inputs: List[str] = []
93+
# Prefer deriving from the actual input tensors to preserve duplicates and order
94+
used_fallback = False
6095
try:
61-
# Keras 2/TF 2.x
96+
# Try node-based tensors first if available to reflect graph edges
6297
inbound_nodes = getattr(layer, "_inbound_nodes", [])
6398
for node in inbound_nodes:
64-
in_layers = getattr(node, "inbound_layers", [])
65-
if not isinstance(in_layers, (list, tuple)):
66-
in_layers = [in_layers]
67-
for in_l in in_layers:
68-
if in_l is None:
69-
continue
70-
inputs.append(in_l.name)
99+
kin = getattr(node, "keras_inputs", None) or getattr(node, "input_tensors", None)
100+
if kin is not None:
101+
for t in _flatten_inputs(kin):
102+
inputs.append(_upstream_layer_name_from_tensor(t))
103+
else:
104+
in_layers = getattr(node, "inbound_layers", [])
105+
if not isinstance(in_layers, (list, tuple)):
106+
in_layers = [in_layers]
107+
for in_l in in_layers:
108+
if in_l is None:
109+
continue
110+
inputs.append(in_l.name)
71111
except Exception:
72-
pass
73-
74-
# Deduplicate while preserving order
75-
seen = set()
76-
inputs = [x for x in inputs if not (x in seen or seen.add(x))]
112+
used_fallback = True
113+
114+
if used_fallback or not inputs:
115+
try:
116+
tensors = getattr(layer, "input", None)
117+
for t in _flatten_inputs(tensors):
118+
inputs.append(_upstream_layer_name_from_tensor(t))
119+
except Exception:
120+
pass
77121

78122
nodes.append(
79123
{
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
# Force CPU for deterministic behavior
3+
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
4+
5+
import tensorflow as tf
6+
7+
from cerebros.keras_export import export_keras_to_graph
8+
9+
# Also disable GPUs at the TensorFlow level in case the env var is ignored
10+
try:
11+
tf.config.set_visible_devices([], "GPU")
12+
except Exception:
13+
pass
14+
15+
16+
def test_concatenate_preserves_duplicate_inputs():
17+
inp = tf.keras.Input(shape=(4,), name="inp")
18+
a = tf.keras.layers.Dense(2, activation="linear", name="a")(inp)
19+
b = tf.keras.layers.Dense(5, activation="linear", name="b")(inp)
20+
21+
cat = tf.keras.layers.Concatenate(axis=1, name="cat")([inp, inp, a, a, b, b])
22+
out = tf.keras.layers.Dense(1, name="out")(cat)
23+
model = tf.keras.Model(inputs=inp, outputs=out)
24+
25+
spec = export_keras_to_graph(model)
26+
# Find the cat node
27+
cat_nodes = [n for n in spec["nodes"] if n["name"] == "cat"]
28+
assert len(cat_nodes) == 1
29+
inputs = cat_nodes[0]["inputs"]
30+
assert inputs == ["inp", "inp", "a", "a", "b", "b"], inputs

0 commit comments

Comments
 (0)