@@ -44,10 +44,43 @@ def _tensor_name(t) -> str:
4444 return str (t )
4545
4646
47+ def _history_layer_of_tensor (t ):
48+ """Return producing layer for a Keras/TensorFlow tensor if available."""
49+ for attr in ("_keras_history" , "keras_history" ):
50+ if hasattr (t , attr ):
51+ hist = getattr (t , attr )
52+ # Keras 3 may expose .layer; Keras 2 exposes tuple (layer, node_index, tensor_index)
53+ layer = getattr (hist , "layer" , None )
54+ if layer is not None :
55+ return layer
56+ try :
57+ return hist [0 ]
58+ except Exception :
59+ return None
60+ return None
61+
62+
63+ def _upstream_layer_name_from_tensor (t ) -> str :
64+ layer = _history_layer_of_tensor (t )
65+ return getattr (layer , "name" , _tensor_name (t ))
66+
67+
68+ def _flatten_inputs (obj ):
69+ if obj is None :
70+ return []
71+ if isinstance (obj , (list , tuple )):
72+ out = []
73+ for x in obj :
74+ out .extend (_flatten_inputs (x ))
75+ return out
76+ return [obj ]
77+
78+
4779def export_keras_to_graph (model ) -> Dict [str , Any ]:
4880 """Export a tf.keras.Model to a minimal DAG spec.
4981
5082 Requires a Functional or Model graph (not a pure Sequential with no names).
83+ Preserves multiplicity and order of inbound connections.
5184 """
5285 if tf is None :
5386 raise RuntimeError ("TensorFlow not available; cannot export Keras model." )
@@ -57,23 +90,34 @@ def export_keras_to_graph(model) -> Dict[str, Any]:
5790 # Build list of nodes with inbound connections (layer order is already topological)
5891 for layer in model .layers :
5992 inputs : List [str ] = []
93+ # Prefer deriving from the actual input tensors to preserve duplicates and order
94+ used_fallback = False
6095 try :
61- # Keras 2/TF 2.x
96+ # Try node-based tensors first if available to reflect graph edges
6297 inbound_nodes = getattr (layer , "_inbound_nodes" , [])
6398 for node in inbound_nodes :
64- in_layers = getattr (node , "inbound_layers" , [])
65- if not isinstance (in_layers , (list , tuple )):
66- in_layers = [in_layers ]
67- for in_l in in_layers :
68- if in_l is None :
69- continue
70- inputs .append (in_l .name )
99+ kin = getattr (node , "keras_inputs" , None ) or getattr (node , "input_tensors" , None )
100+ if kin is not None :
101+ for t in _flatten_inputs (kin ):
102+ inputs .append (_upstream_layer_name_from_tensor (t ))
103+ else :
104+ in_layers = getattr (node , "inbound_layers" , [])
105+ if not isinstance (in_layers , (list , tuple )):
106+ in_layers = [in_layers ]
107+ for in_l in in_layers :
108+ if in_l is None :
109+ continue
110+ inputs .append (in_l .name )
71111 except Exception :
72- pass
73-
74- # Deduplicate while preserving order
75- seen = set ()
76- inputs = [x for x in inputs if not (x in seen or seen .add (x ))]
112+ used_fallback = True
113+
114+ if used_fallback or not inputs :
115+ try :
116+ tensors = getattr (layer , "input" , None )
117+ for t in _flatten_inputs (tensors ):
118+ inputs .append (_upstream_layer_name_from_tensor (t ))
119+ except Exception :
120+ pass
77121
78122 nodes .append (
79123 {
0 commit comments