11import numpy as np
22
33from devito .arch .compiler import Compiler
4- from devito .ir import Callable , FindSymbols , SymbolRegistry
4+ from devito .ir import Callable , SymbolRegistry
5+ from devito .ir .iet .utils import has_dtype
56from devito .passes .iet .engine import iet_pass
67from devito .passes .iet .langbase import LangBB
78from devito .tools import as_tuple
1011
1112
1213@iet_pass
13- def _complex_includes (iet : Callable , lang : type [LangBB ], compiler : Compiler ,
14+ def _complex_includes (iet : Callable , langbb : type [LangBB ], compiler : Compiler ,
1415 sregistry : SymbolRegistry ) -> tuple [Callable , dict ]:
1516 """
1617 Includes complex arithmetic headers for the given language, if needed.
1718 """
1819 # Check if there are complex numbers that always take dtype precedence
19- for f in FindSymbols ().visit (iet ):
20- try :
21- if np .issubdtype (f .dtype , np .complexfloating ):
22- break
23- except TypeError :
24- continue
25- else :
20+ if not has_dtype (iet , np .complexfloating ):
2621 return iet , {}
2722
2823 metadata = {}
29- lib = as_tuple (lang ['includes-complex' ])
24+ lib = as_tuple (langbb ['includes-complex' ])
3025
31- if lang .get ('complex-namespace' ) is not None :
32- metadata ['namespaces' ] = lang ['complex-namespace' ]
26+ if langbb .get ('complex-namespace' ) is not None :
27+ metadata ['namespaces' ] = langbb ['complex-namespace' ]
3328
3429 # Some languges such as c++11 need some extra arithmetic definitions
35- if lang .get ('def-complex' ):
30+ if langbb .get ('def-complex' ):
3631 dest = compiler .get_jit_dir ()
3732 hfile = dest .joinpath ('complex_arith.h' )
3833 with open (str (hfile ), 'w' ) as ff :
39- ff .write (str (lang ['def-complex' ]))
34+ ff .write (str (langbb ['def-complex' ]))
4035 lib += (str (hfile ),)
4136
4237 metadata ['includes' ] = lib
@@ -47,12 +42,14 @@ def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler,
4742dtype_passes = [_complex_includes ]
4843
4944
50- def lower_dtypes (graph : Callable , lang : type [LangBB ] = None , compiler : Compiler = None ,
45+ def lower_dtypes (graph : Callable ,
46+ langbb : type [LangBB ] = None ,
47+ compiler : Compiler = None ,
5148 sregistry : SymbolRegistry = None , ** kwargs ) -> tuple [Callable , dict ]:
5249 """
5350 Lowers float16 scalar types to pointers since we can't directly pass their
5451 value. Also includes headers for complex arithmetic if needed.
5552 """
5653
5754 for dtype_pass in dtype_passes :
58- dtype_pass (graph , lang = lang , compiler = compiler , sregistry = sregistry )
55+ dtype_pass (graph , langbb = langbb , compiler = compiler , sregistry = sregistry )
0 commit comments