11from pytensor .link .basic import JITLinker
2- from pytensor .link .utils import unique_name_generator
32
43
54class MLXLinker (JITLinker ):
@@ -9,29 +8,13 @@ def __init__(self, *args, **kwargs):
98 super ().__init__ (* args , ** kwargs )
109 self .gen_functors = []
1110
12- def fgraph_convert (
13- self ,
14- fgraph ,
15- order ,
16- input_storage ,
17- output_storage ,
18- storage_map ,
19- ** kwargs ,
20- ):
11+ def fgraph_convert (self , fgraph , ** kwargs ):
2112 """Convert a PyTensor FunctionGraph to an MLX-compatible function.
2213
2314 Parameters
2415 ----------
2516 fgraph : FunctionGraph
2617 The function graph to convert
27- order : list
28- The order in which to compute the nodes
29- input_storage : list
30- Storage for the input variables
31- output_storage : list
32- Storage for the output variables
33- storage_map : dict
34- Map from variables to their storage
3518
3619 Returns
3720 -------
@@ -40,27 +23,9 @@ def fgraph_convert(
4023 """
4124 from pytensor .link .mlx .dispatch import mlx_funcify
4225
43- # We want to have globally unique names
44- # across the entire pytensor graph, not
45- # just the subgraph
46- generator = unique_name_generator (["mlx_linker" ])
47-
48- def conversion_func_register (* args , ** kwargs ):
49- functor = mlx_funcify (* args , ** kwargs )
50- name = kwargs ["unique_name" ](functor )
51- self .gen_functors .append ((f"_{ name } " , functor ))
52- return functor
53-
54- built_kwargs = {
55- "unique_name" : generator ,
56- "conversion_func" : conversion_func_register ,
57- ** kwargs ,
58- }
5926 return mlx_funcify (
6027 fgraph ,
61- input_storage = input_storage ,
62- storage_map = storage_map ,
63- ** built_kwargs ,
28+ ** kwargs ,
6429 )
6530
6631 def jit_compile (self , fn ):
0 commit comments