@@ -108,11 +108,7 @@ def foreach(f, *args):
108108 return None
109109
110110else :
111- # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum.
112- if hasattr (jaxlib_utils , 'foreach' ):
113- foreach = jaxlib_utils .foreach
114- else :
115- foreach = safe_map
111+ foreach = jaxlib_utils .foreach
116112
117113
118114def unzip2 (xys : Iterable [tuple [T1 , T2 ]]
@@ -244,61 +240,8 @@ def curry(f):
244240 """
245241 return wraps (f )(partial (partial , f ))
246242
247- # TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
248243toposort : Callable [[Iterable [Any ]], list [Any ]]
249- if hasattr (jaxlib_utils , "topological_sort" ):
250- toposort = partial (jaxlib_utils .topological_sort , "parents" )
251- else :
252-
253- def toposort (end_nodes ):
254- if not end_nodes :
255- return []
256- end_nodes = _remove_duplicates (end_nodes )
257-
258- child_counts = {}
259- stack = list (end_nodes )
260- while stack :
261- node = stack .pop ()
262- if id (node ) in child_counts :
263- child_counts [id (node )] += 1
264- else :
265- child_counts [id (node )] = 1
266- stack .extend (node .parents )
267- for node in end_nodes :
268- child_counts [id (node )] -= 1
269-
270- sorted_nodes = []
271- childless_nodes = [
272- node for node in end_nodes if child_counts [id (node )] == 0
273- ]
274- assert childless_nodes
275- while childless_nodes :
276- node = childless_nodes .pop ()
277- sorted_nodes .append (node )
278- for parent in node .parents :
279- if child_counts [id (parent )] == 1 :
280- childless_nodes .append (parent )
281- else :
282- child_counts [id (parent )] -= 1
283- sorted_nodes = sorted_nodes [::- 1 ]
284-
285- check_toposort (sorted_nodes )
286- return sorted_nodes
287-
288- def check_toposort (nodes ):
289- visited = set ()
290- for node in nodes :
291- assert all (id (parent ) in visited for parent in node .parents )
292- visited .add (id (node ))
293-
294- def _remove_duplicates (node_list ):
295- seen = set ()
296- out = []
297- for n in node_list :
298- if id (n ) not in seen :
299- seen .add (id (n ))
300- out .append (n )
301- return out
244+ toposort = partial (jaxlib_utils .topological_sort , "parents" )
302245
303246
304247def split_merge (predicate , xs ):
@@ -320,7 +263,6 @@ def merge(new_lhs, new_rhs):
320263
321264 return lhs , rhs , merge
322265
323-
324266def _ignore (): return None
325267
326268
0 commit comments