Skip to content

Commit 952eaa8

Browse files
committed
partial support for intrinsic functions
1 parent 106fcab commit 952eaa8

File tree

2 files changed

+328
-71
lines changed

2 files changed

+328
-71
lines changed

open_fortran_parser/ast_transformer.py

Lines changed: 234 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,13 @@ def _argument(self, node: ET.Element) -> typed_ast3.arg:
152152
if 'name' not in node.attrib:
153153
raise SyntaxError(
154154
'"name" attribute not present in:\n{}'.format(ET.tostring(node).decode().rstrip()))
155-
self._ensure_top_level_import('typing', 't')
156-
return typed_ast3.arg(
157-
arg=node.attrib['name'], annotation=typed_ast3.Attribute(
158-
value=typed_ast3.Name(id='t', ctx=typed_ast3.Load()),
159-
attr='Any', ctx=typed_ast3.Load()))
155+
values = self.transform_all_subnodes(
156+
node, warn=False, skip_empty=False,
157+
ignored={'actual-arg', 'actual-arg-spec', 'dummy-arg'})
158+
if values:
159+
assert len(values) == 1
160+
return typed_ast3.keyword(arg=node.attrib['name'], value=values[0])
161+
return typed_ast3.arg(arg=node.attrib['name'], annotation=None)
160162

161163
def _program(self, node: ET.Element) -> typed_ast3.AST:
162164
module = typed_ast3.parse('''if __name__ == '__main__':\n pass''')
@@ -251,7 +253,7 @@ def _declaration_variable(
251253
for i, (var, val) in enumerate(variables):
252254
val = typed_ast3.Call(
253255
func=typed_ast3.Attribute(
254-
value=typed_ast3.Name(id='np'), attr='ndarray', ctx=typed_ast3.Load()),
256+
value=typed_ast3.Name(id='np'), attr='zeros', ctx=typed_ast3.Load()),
255257
args=[typed_ast3.Tuple(elts=dimensions)],
256258
keywords=[typed_ast3.keyword(arg='dtype', value=data_type)])
257259
variables[i] = (var, val)
@@ -287,8 +289,8 @@ def _loop(self, node: ET.Element):
287289
elif node.attrib['type'] == 'forall':
288290
return self._loop_forall(node)
289291
else:
290-
_LOG.error('%s', ET.tostring(node).decode().rstrip())
291-
raise NotImplementedError()
292+
raise NotImplementedError(
293+
'not implemented handling of:\n{}'.format(ET.tostring(node).decode().rstrip()))
292294

293295
def _loop_do(self, node: ET.Element) -> typed_ast3.For:
294296
index_variable = node.find('./header/index-variable')
@@ -346,7 +348,6 @@ def _index_variable(self, node: ET.Element) -> t.Tuple[typed_ast3.Name, typed_as
346348
return target, iter_
347349

348350
def _if(self, node: ET.Element):
349-
#_LOG.warning('if header:')
350351
header = self.transform_all_subnodes(
351352
node.find('./header'), warn=False,
352353
ignored={'executable-construct', 'execution-part-construct'})
@@ -382,23 +383,49 @@ def _statement(self, node: ET.Element):
382383
else typed_ast3.Expr(value=detail)
383384
for detail in details]
384385

386+
def _allocations(self, node: ET.Element) -> typed_ast3.Assign:
387+
allocation_nodes = node.findall('./allocation')
388+
allocations = []
389+
for allocation_node in allocation_nodes:
390+
if not allocation_node:
391+
continue
392+
allocation = self.transform_all_subnodes(allocation_node, warn=False)
393+
assert len(allocation) == 1
394+
allocations.append(allocation[0])
395+
assert len(allocations) == int(node.attrib['count']), (len(allocations), node.attrib['count'])
396+
assignments = []
397+
for allocation in allocations:
398+
assert isinstance(allocation, typed_ast3.Subscript)
399+
var = allocation.value
400+
if isinstance(allocation.slice, typed_ast3.Index):
401+
sizes = [allocation.slice.value]
402+
elif isinstance(allocation.slice, typed_ast3.ExtSlice):
403+
sizes = allocation.slice.dims
404+
else:
405+
raise NotImplementedError('unrecognized slice type: "{}"'.format(type(allocation.slice)))
406+
val = typed_ast3.Call(
407+
func=typed_ast3.Attribute(
408+
value=typed_ast3.Name(id='np'), attr='zeros', ctx=typed_ast3.Load()),
409+
args=[typed_ast3.Tuple(elts=sizes)], keywords=[typed_ast3.keyword(arg='dtype', value='t.Any')])
410+
assignments.append(
411+
typed_ast3.Assign(targets=[var], value=val, type_comment=None))
412+
return assignments
413+
385414
def _call(self, node: ET.Element) -> t.Union[typed_ast3.Call, typed_ast3.Assign]:
386415
called = self.transform_all_subnodes(node, warn=False, ignored={'call-stmt'})
387416
if len(called) != 1:
388-
_LOG.warning('%s', ET.tostring(node).decode().rstrip())
389-
_LOG.error('%s', [typed_astunparse.unparse(_).rstrip() for _ in called])
390-
raise SyntaxError("call statement must contain a single called object")
391-
if isinstance(called[0], typed_ast3.Call):
392-
call = called[0]
393-
else:
394-
_LOG.warning('called an ambiguous node')
395-
_LOG.warning('%s', ET.tostring(node).decode().rstrip())
396-
func = called[0]
397-
#assert name.tag == 'name' or name.
398-
args = []
399-
#if isinstance(name, typed_ast3.Subscript):
400-
#args = node.findall('./name/subscripts/subscript')
401-
call = typed_ast3.Call(func=func, args=args, keywords=[])
417+
raise SyntaxError(
418+
'call statement must contain a single called object, not {}, like in:\n{}'.format(
419+
[typed_astunparse.unparse(_).rstrip() for _ in called],
420+
ET.tostring(node).decode().rstrip()))
421+
call = called[0]
422+
if not isinstance(call, typed_ast3.Call):
423+
name_node = node.find('./name')
424+
is_intrinsic = name_node.attrib['id'] in self._intrinsics_converters if name_node is not None else False
425+
if is_intrinsic:
426+
return call
427+
_LOG.warning('called an ambiguous node:\n%s', ET.tostring(node).decode().rstrip())
428+
call = typed_ast3.Call(func=call, args=[], keywords=[])
402429
if isinstance(call.func, typed_ast3.Name) and call.func.id.startswith('MPI_'):
403430
call = self._transform_mpi_call(call)
404431
return call
@@ -693,6 +720,24 @@ def _array_constructor_values(self, node: ET.Element) -> typed_ast3.List:
693720

694721
return typed_ast3.List(elts=values, ctx=typed_ast3.Load())
695722

723+
def _range(self, node: ET.Element) -> typed_ast3.Slice:
724+
lower_bound = node.find('./lower-bound')
725+
upper_bound = node.find('./upper-bound')
726+
step = node.find('./step')
727+
if lower_bound is not None:
728+
args = self.transform_all_subnodes(lower_bound)
729+
assert len(args) == 1, args
730+
lower_bound = args[0]
731+
if upper_bound is not None:
732+
args = self.transform_all_subnodes(upper_bound)
733+
assert len(args) == 1, args
734+
upper_bound = args[0]
735+
if step is not None:
736+
args = self.transform_all_subnodes(step)
737+
assert len(args) == 1, args
738+
step = args[0]
739+
return typed_ast3.Slice(lower=lower_bound, upper=upper_bound, step=step)
740+
696741
def _dimension(self, node: ET.Element) -> t.Union[typed_ast3.Index, typed_ast3.Slice]:
697742
dim_type = node.attrib['type']
698743
if dim_type == 'simple':
@@ -701,27 +746,7 @@ def _dimension(self, node: ET.Element) -> t.Union[typed_ast3.Index, typed_ast3.S
701746
_LOG.error('simple dimension should have exactly one value, but it has %i', len(values))
702747
return typed_ast3.Index(value=values[0])
703748
elif dim_type == 'range':
704-
lower_bound = node.find('./lower-bound')
705-
upper_bound = node.find('./upper-bound')
706-
step = node.find('./step')
707-
#range_args = []
708-
if lower_bound is not None:
709-
args = self.transform_all_subnodes(lower_bound)
710-
assert len(args) == 1, args
711-
#range_args.append(args[0])
712-
lower_bound = args[0]
713-
if upper_bound is not None:
714-
args = self.transform_all_subnodes(upper_bound)
715-
assert len(args) == 1, args
716-
#range_args.append(typed_ast3.BinOp(
717-
# left=args[0], op=typed_ast3.Add(), right=typed_ast3.Num(n=1)))
718-
upper_bound = args[0]
719-
if step is not None:
720-
args = self.transform_all_subnodes(step)
721-
assert len(args) == 1, args
722-
#range_args.append(args[0])
723-
step = args[0]
724-
return typed_ast3.Slice(lower=lower_bound, upper=upper_bound, step=step)
749+
return self._range(node)
725750
elif dim_type == 'assumed-shape':
726751
return typed_ast3.Slice(lower=None, upper=None, step=None)
727752
else:
@@ -753,7 +778,7 @@ def _type(self, node: ET.Element) -> type:
753778
if length is not None:
754779
if isinstance(length, typed_ast3.Num):
755780
length = length.n
756-
_LOG.warning(
781+
_LOG.info(
757782
'ignoring string length "%i" in:\n%s',
758783
length, ET.tostring(node).decode().rstrip())
759784
return typed_ast3.parse(self._basic_types[name, t.Any], mode='eval')
@@ -814,37 +839,180 @@ def _variable(self, node: ET.Element) -> t.Tuple[
814839
def _names(self, node: ET.Element) -> typed_ast3.arguments:
815840
return self._arguments(node)
816841

842+
def _intrinsic_identity(self, call):
843+
return call
844+
845+
def _intrinsic_getenv(self, call):
846+
assert isinstance(call, typed_ast3.Call), type(call)
847+
assert len(call.args) == 2, call.args
848+
self._ensure_top_level_import('os')
849+
target = call.args[1]
850+
if isinstance(target, typed_ast3.keyword):
851+
target = target.value
852+
return typed_ast3.Assign(
853+
targets=[target],
854+
value=typed_ast3.Subscript(
855+
value=typed_ast3.Attribute(value=typed_ast3.Name(id='os', ctx=typed_ast3.Load()),
856+
attr='environ', ctx=typed_ast3.Load()),
857+
slice=typed_ast3.Index(value=call.args[0]), ctx=typed_ast3.Load())
858+
, type_comment=None)
859+
860+
def _intrinsic_count(self, call):
861+
assert isinstance(call, typed_ast3.Call), type(call)
862+
assert len(call.args) == 1, call.args
863+
return typed_ast3.Call(
864+
func=typed_ast3.Attribute(value=call.args[0], attr='sum', ctx=typed_ast3.Load()),
865+
args=[], keywords=[])
866+
867+
def _intrinsic_converter_not_implemented(self, call):
868+
raise NotImplementedError(
869+
"cannot convert intrinsic call from raw AST:\n{}"
870+
.format(typed_astunparse.unparse(call)))
871+
872+
_intrinsics_converters = {
873+
# Fortran 77
874+
'abs': _intrinsic_identity, # np.absolute
875+
'acos': _intrinsic_converter_not_implemented,
876+
'aimag': _intrinsic_converter_not_implemented,
877+
'aint': _intrinsic_converter_not_implemented,
878+
'anint': _intrinsic_converter_not_implemented,
879+
'asin': _intrinsic_converter_not_implemented,
880+
'atan': _intrinsic_converter_not_implemented,
881+
'atan2': _intrinsic_converter_not_implemented,
882+
'char': _intrinsic_converter_not_implemented,
883+
'cmplx': _intrinsic_converter_not_implemented,
884+
'conjg': _intrinsic_converter_not_implemented,
885+
'cos': _intrinsic_converter_not_implemented,
886+
'cosh': _intrinsic_converter_not_implemented,
887+
'dble': _intrinsic_converter_not_implemented,
888+
'dim': _intrinsic_converter_not_implemented,
889+
'dprod': _intrinsic_converter_not_implemented,
890+
'exp': _intrinsic_converter_not_implemented,
891+
'ichar': _intrinsic_converter_not_implemented,
892+
'index': _intrinsic_converter_not_implemented,
893+
'int': _intrinsic_identity,
894+
'len': _intrinsic_converter_not_implemented,
895+
'lge': _intrinsic_converter_not_implemented,
896+
'lgt': _intrinsic_converter_not_implemented,
897+
'lle': _intrinsic_converter_not_implemented,
898+
'llt': _intrinsic_converter_not_implemented,
899+
'log': _intrinsic_converter_not_implemented,
900+
'log10': _intrinsic_converter_not_implemented,
901+
'max': _intrinsic_converter_not_implemented,
902+
'min': _intrinsic_converter_not_implemented,
903+
'mod': _intrinsic_converter_not_implemented,
904+
'nint': _intrinsic_converter_not_implemented,
905+
'real': _intrinsic_converter_not_implemented,
906+
'sign': _intrinsic_converter_not_implemented,
907+
'sin': _intrinsic_converter_not_implemented,
908+
'sinh': _intrinsic_converter_not_implemented,
909+
'sqrt': _intrinsic_converter_not_implemented,
910+
'tan': _intrinsic_converter_not_implemented,
911+
'tanh': _intrinsic_converter_not_implemented,
912+
# non-standard Fortran 77
913+
'getenv': _intrinsic_getenv,
914+
# Fortran 90
915+
# Character string functions
916+
'achar': _intrinsic_converter_not_implemented,
917+
'adjustl': _intrinsic_converter_not_implemented,
918+
'adjustr': _intrinsic_converter_not_implemented,
919+
'iachar': _intrinsic_converter_not_implemented,
920+
'len_trim': _intrinsic_converter_not_implemented,
921+
'repeat': _intrinsic_converter_not_implemented,
922+
'scan': _intrinsic_converter_not_implemented,
923+
'trim': lambda self, call: typed_ast3.Call(
924+
func=typed_ast3.Attribute(value=call.args[0], attr='rstrip', ctx=typed_ast3.Load()),
925+
args=call.args[1:], keywords=[]),
926+
'verify': _intrinsic_converter_not_implemented,
927+
# Logical function
928+
'logical': _intrinsic_converter_not_implemented,
929+
# Numerical inquiry functions
930+
'digits': _intrinsic_converter_not_implemented,
931+
'epsilon': _intrinsic_converter_not_implemented,
932+
'huge': _intrinsic_converter_not_implemented,
933+
'maxexponent': _intrinsic_converter_not_implemented,
934+
'minexponent': _intrinsic_converter_not_implemented,
935+
'precision': _intrinsic_converter_not_implemented,
936+
'radix': _intrinsic_converter_not_implemented,
937+
'range': _intrinsic_converter_not_implemented,
938+
'tiny': _intrinsic_converter_not_implemented,
939+
# Bit inquiry function
940+
'bit_size': _intrinsic_converter_not_implemented,
941+
# Vector- and matrix-multiplication functions
942+
'dot_product': _intrinsic_converter_not_implemented,
943+
'matmul': _intrinsic_converter_not_implemented,
944+
# Array functions
945+
'all': _intrinsic_converter_not_implemented,
946+
'any': _intrinsic_converter_not_implemented,
947+
'count': _intrinsic_count,
948+
'maxval': _intrinsic_converter_not_implemented,
949+
'minval': _intrinsic_converter_not_implemented,
950+
'product': _intrinsic_converter_not_implemented,
951+
'sum': _intrinsic_identity,
952+
# Array location functions
953+
'maxloc': lambda self, call: typed_ast3.Call(
954+
func=typed_ast3.Attribute(value=typed_ast3.Name(id='np', ctx=typed_ast3.Load()),
955+
attr='argmax', ctx=typed_ast3.Load()),
956+
args=call.args, keywords=call.keywords),
957+
'minloc': lambda self, call: typed_ast3.Call(
958+
func=typed_ast3.Attribute(value=typed_ast3.Name(id='np', ctx=typed_ast3.Load()),
959+
attr='argmin', ctx=typed_ast3.Load()),
960+
args=call.args, keywords=call.keywords),
961+
# Fortran 95
962+
'cpu_time': _intrinsic_converter_not_implemented,
963+
'present': _intrinsic_converter_not_implemented,
964+
'set_exponent': _intrinsic_converter_not_implemented,
965+
# Fortran 2003
966+
# Fortran 2008
967+
}
968+
817969
def _name(self, node: ET.Element) -> typed_ast3.AST:
818-
name = typed_ast3.Name(id=node.attrib['id'], ctx=typed_ast3.Load())
819-
if 'type' in node.attrib:
820-
name_type = node.attrib['type']
821-
else:
822-
name_type = None
823-
#_LOG.warning('%s', ET.tostring(node).decode().rstrip())
824-
#raise NotImplementedError()
970+
name_str = node.attrib['id']
971+
name = typed_ast3.Name(id=name_str, ctx=typed_ast3.Load())
972+
name_str = name_str.lower()
973+
name_type = node.attrib['type'] if 'type' in node.attrib else None
974+
is_intrinsic = name_str in self._intrinsics_converters
975+
825976
subscripts_node = node.find('./subscripts')
826-
if name_type == "procedure":
977+
try:
827978
args = self._args(subscripts_node) if subscripts_node else []
828-
return typed_ast3.Call(func=name, args=args, keywords=[])
829-
if not subscripts_node:
979+
call = typed_ast3.Call(func=name, args=args, keywords=[])
980+
if is_intrinsic:
981+
name_type = "function"
982+
call = self._intrinsics_converters[name_str](self, call)
983+
except SyntaxError:
984+
_LOG.info('transforming name to call failed as below (continuing despite that)', exc_info=True)
985+
986+
slice_ = self._subscripts(subscripts_node) if subscripts_node else None
987+
subscript = typed_ast3.Subscript(value=name, slice=slice_, ctx=typed_ast3.Load())
988+
989+
if name_type in ("procedure", "function"):
990+
return call
991+
elif not subscripts_node:
830992
return name
831-
slice_ = self._subscripts(subscripts_node)
832-
if not slice_:
833-
return typed_ast3.Call(func=name, args=[], keywords=[])
834-
return typed_ast3.Subscript(value=name, slice=slice_, ctx=typed_ast3.Load())
993+
elif name_type in ("variable",):
994+
return subscript
995+
elif not slice_:
996+
return call
997+
elif name_type in ("ambiguous",):
998+
return subscript
999+
elif name_type is not None:
1000+
raise NotImplementedError('unrecognized name type "{}" in:\n{}'.format(name_type, ET.tostring(node).decode().rstrip()))
1001+
elif name_type is None:
1002+
raise NotImplementedError('no name type in:\n{}'.format(ET.tostring(node).decode().rstrip()))
1003+
raise NotImplementedError()
8351004

8361005
def _args(self, node: ET.Element, arg_node_name: str = 'subscript') -> t.List[typed_ast3.AST]:
8371006
args = []
8381007
for arg_node in node.findall(f'./{arg_node_name}'):
839-
new_args = self.transform_all_subnodes(
840-
arg_node, warn=False, skip_empty=True,
841-
ignored={'section-subscript', 'actual-arg', 'actual-arg-spec', 'argument'})
1008+
new_args = self.transform_all_subnodes(arg_node, warn=False, skip_empty=True)
8421009
if not new_args:
8431010
continue
8441011
if len(new_args) != 1:
845-
_LOG.error('%s', ET.tostring(arg_node).decode().rstrip())
846-
_LOG.error('%s', [typed_astunparse.unparse(_) for _ in new_args])
847-
raise SyntaxError('args must be specified one new arg at a time')
1012+
raise SyntaxError(
1013+
'args must be specified one new arg at a time, not like {} in:\n{}'.format(
1014+
[typed_astunparse.unparse(_) for _ in new_args],
1015+
ET.tostring(arg_node).decode().rstrip()))
8481016
args += new_args
8491017
return args
8501018

@@ -853,8 +1021,7 @@ def _subscripts(
8531021
typed_ast3.Index, typed_ast3.Slice, typed_ast3.ExtSlice]:
8541022
subscripts = []
8551023
for subscript in node.findall('./subscript'):
856-
new_subscripts = self.transform_all_subnodes(
857-
subscript, warn=False, ignored={'section-subscript'})
1024+
new_subscripts = self.transform_all_subnodes(subscript, warn=False)
8581025
if not new_subscripts:
8591026
continue
8601027
if len(new_subscripts) == 1:

0 commit comments

Comments
 (0)