Skip to content

Commit 298de04

Browse files
majosminducer
authored andcommitted
add transform_dag implementation to pytato JAX array context in order to inline functions
1 parent 300da3c commit 298de04

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,8 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
852852
An arraycontext that uses :mod:`pytato` to represent the thawed state of
853853
the arrays and compiles the expressions using
854854
:class:`pytato.target.python.JAXPythonTarget`.
855+
856+
.. automethod:: transform_dag
855857
"""
856858

857859
def __init__(self,
@@ -984,6 +986,13 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
984986
from .compile import LazilyJAXCompilingFunctionCaller
985987
return LazilyJAXCompilingFunctionCaller(self, f)
986988

989+
def transform_dag(self, dag: pytato.DictOfNamedArrays
990+
) -> pytato.DictOfNamedArrays:
991+
import pytato as pt
992+
dag = pt.tag_all_calls_to_be_inlined(dag)
993+
dag = pt.inline_calls(dag)
994+
return dag
995+
987996
@override
988997
def tag(self,
989998
tags: ToTagSetConvertible,

0 commit comments

Comments
 (0)