1111#
1212# SPDX-License-Identifier: GPL-3.0-or-later
1313
14- from dataclasses import dataclass
15- from typing import Any , Callable , Optional , Protocol , Type
14+ from typing import Any , Optional , Protocol , Type
1615
16+ import icon4py .liskov .parsing .parse
1717import icon4py .liskov .parsing .types as ts
18- from icon4py .liskov .codegen .interface import (
18+ from icon4py .common .logger import setup_logger
19+ from icon4py .liskov .codegen .integration .interface import (
1920 BoundsData ,
20- CodeGenInput ,
2121 DeclareData ,
22- DeserialisedDirectives ,
2322 EndCreateData ,
2423 EndIfData ,
2524 EndProfileData ,
2625 EndStencilData ,
2726 FieldAssociationData ,
2827 ImportsData ,
2928 InsertData ,
29+ IntegrationCodeInterface ,
3030 StartCreateData ,
3131 StartProfileData ,
3232 StartStencilData ,
3333 UnusedDirective ,
3434)
35- from icon4py .liskov .common import Step
36- from icon4py .liskov .logger import setup_logger
35+ from icon4py .liskov .codegen . shared . deserialise import Deserialiser
36+ from icon4py .liskov .codegen . shared . types import CodeGenInput
3737from icon4py .liskov .parsing .exceptions import (
3838 DirectiveSyntaxError ,
3939 MissingBoundsError ,
@@ -89,13 +89,11 @@ def __call__(
8989 ...
9090
9191
92- @dataclass
9392class DataFactoryBase :
9493 directive_cls : Type [ts .ParsedDirective ]
9594 dtype : Type [CodeGenInput ]
9695
9796
98- @dataclass
9997class OptionalMultiUseDataFactory (DataFactoryBase ):
10098 def __call__ (
10199 self , parsed : ts .ParsedDict , ** kwargs : Any
@@ -106,48 +104,38 @@ def __call__(
106104 else :
107105 deserialised = []
108106 for directive in extracted :
109- deserialised .append (
110- self .dtype (
111- startln = directive .startln , endln = directive .endln , ** kwargs
112- )
113- )
107+ deserialised .append (self .dtype (startln = directive .startln , ** kwargs ))
114108 return deserialised
115109
116110
117- @dataclass
118111class RequiredSingleUseDataFactory (DataFactoryBase ):
119112 def __call__ (self , parsed : ts .ParsedDict ) -> CodeGenInput :
120113 extracted = extract_directive (parsed ["directives" ], self .directive_cls )[0 ]
121- return self .dtype (startln = extracted .startln , endln = extracted . endln )
114+ return self .dtype (startln = extracted .startln )
122115
123116
124- @dataclass
125117class EndCreateDataFactory (RequiredSingleUseDataFactory ):
126- directive_cls : Type [ts .ParsedDirective ] = ts .EndCreate
118+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .EndCreate
127119 dtype : Type [EndCreateData ] = EndCreateData
128120
129121
130- @dataclass
131122class ImportsDataFactory (RequiredSingleUseDataFactory ):
132- directive_cls : Type [ts .ParsedDirective ] = ts .Imports
123+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .Imports
133124 dtype : Type [ImportsData ] = ImportsData
134125
135126
136- @dataclass
137127class EndIfDataFactory (OptionalMultiUseDataFactory ):
138- directive_cls : Type [ts .ParsedDirective ] = ts .EndIf
128+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .EndIf
139129 dtype : Type [EndIfData ] = EndIfData
140130
141131
142- @dataclass
143132class EndProfileDataFactory (OptionalMultiUseDataFactory ):
144- directive_cls : Type [ts .ParsedDirective ] = ts .EndProfile
133+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .EndProfile
145134 dtype : Type [EndProfileData ] = EndProfileData
146135
147136
148- @dataclass
149137class StartCreateDataFactory (DataFactoryBase ):
150- directive_cls : Type [ts .ParsedDirective ] = ts .StartCreate
138+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .StartCreate
151139 dtype : Type [StartCreateData ] = StartCreateData
152140
153141 def __call__ (self , parsed : ts .ParsedDict ) -> StartCreateData :
@@ -159,14 +147,11 @@ def __call__(self, parsed: ts.ParsedDict) -> StartCreateData:
159147 if named_args :
160148 extra_fields = named_args ["extra_fields" ].split ("," )
161149
162- return self .dtype (
163- startln = directive .startln , endln = directive .endln , extra_fields = extra_fields
164- )
150+ return self .dtype (startln = directive .startln , extra_fields = extra_fields )
165151
166152
167- @dataclass
168153class DeclareDataFactory (DataFactoryBase ):
169- directive_cls : Type [ts .ParsedDirective ] = ts .Declare
154+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .Declare
170155 dtype : Type [DeclareData ] = DeclareData
171156
172157 @staticmethod
@@ -185,7 +170,6 @@ def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]:
185170 deserialised .append (
186171 self .dtype (
187172 startln = directive .startln ,
188- endln = directive .endln ,
189173 declarations = named_args ,
190174 ident_type = ident_type ,
191175 suffix = suffix ,
@@ -194,9 +178,8 @@ def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]:
194178 return deserialised
195179
196180
197- @dataclass
198181class StartProfileDataFactory (DataFactoryBase ):
199- directive_cls : Type [ts .ParsedDirective ] = ts .StartProfile
182+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .StartProfile
200183 dtype : Type [StartProfileData ] = StartProfileData
201184
202185 def __call__ (self , parsed : ts .ParsedDict ) -> list [StartProfileData ]:
@@ -206,18 +189,13 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]:
206189 named_args = parsed ["content" ]["StartProfile" ][i ]
207190 stencil_name = _extract_stencil_name (named_args , directive )
208191 deserialised .append (
209- self .dtype (
210- name = stencil_name ,
211- startln = directive .startln ,
212- endln = directive .endln ,
213- )
192+ self .dtype (name = stencil_name , startln = directive .startln )
214193 )
215194 return deserialised
216195
217196
218- @dataclass
219197class EndStencilDataFactory (DataFactoryBase ):
220- directive_cls : Type [ts .ParsedDirective ] = ts .EndStencil
198+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .EndStencil
221199 dtype : Type [EndStencilData ] = EndStencilData
222200
223201 def __call__ (self , parsed : ts .ParsedDict ) -> list [EndStencilData ]:
@@ -232,17 +210,15 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]:
232210 self .dtype (
233211 name = stencil_name ,
234212 startln = directive .startln ,
235- endln = directive .endln ,
236213 noendif = noendif ,
237214 noprofile = noprofile ,
238215 )
239216 )
240217 return deserialised
241218
242219
243- @dataclass
244220class StartStencilDataFactory (DataFactoryBase ):
245- directive_cls : Type [ts .ParsedDirective ] = ts .StartStencil
221+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .StartStencil
246222 dtype : Type [StartStencilData ] = StartStencilData
247223
248224 def __call__ (self , parsed : ts .ParsedDict ) -> list [StartStencilData ]:
@@ -282,7 +258,6 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
282258 fields = fields_w_tolerance ,
283259 bounds = bounds ,
284260 startln = directive .startln ,
285- endln = directive .endln ,
286261 acc_present = acc_present ,
287262 mergecopy = mergecopy ,
288263 copies = copies ,
@@ -377,9 +352,8 @@ def _update_tolerances(
377352 return fields
378353
379354
380- @dataclass
381355class InsertDataFactory (DataFactoryBase ):
382- directive_cls : Type [ts .ParsedDirective ] = ts .Insert
356+ directive_cls : Type [ts .ParsedDirective ] = icon4py . liskov . parsing . parse .Insert
383357 dtype : Type [InsertData ] = InsertData
384358
385359 def __call__ (self , parsed : ts .ParsedDict ) -> list [InsertData ]:
@@ -388,15 +362,13 @@ def __call__(self, parsed: ts.ParsedDict) -> list[InsertData]:
388362 for i , directive in enumerate (extracted ):
389363 content = parsed ["content" ]["Insert" ][i ]
390364 deserialised .append (
391- self .dtype (
392- startln = directive .startln , endln = directive .endln , content = content # type: ignore
393- )
365+ self .dtype (startln = directive .startln , content = content ) # type: ignore
394366 )
395367 return deserialised
396368
397369
398- class DirectiveDeserialiser ( Step ):
399- _FACTORIES : dict [ str , Callable ] = {
370+ class IntegrationCodeDeserialiser ( Deserialiser ):
371+ _FACTORIES = {
400372 "StartCreate" : StartCreateDataFactory (),
401373 "EndCreate" : EndCreateDataFactory (),
402374 "Imports" : ImportsDataFactory (),
@@ -408,27 +380,4 @@ class DirectiveDeserialiser(Step):
408380 "EndProfile" : EndProfileDataFactory (),
409381 "Insert" : InsertDataFactory (),
410382 }
411-
412- def __call__ (self , directives : ts .ParsedDict ) -> DeserialisedDirectives :
413- """Deserialise the provided parsed directives to a DeserialisedDirectives object.
414-
415- Args:
416- directives: The parsed directives to deserialise.
417-
418- Returns:
419- A DeserialisedDirectives object containing the deserialised directives.
420-
421- Note:
422- The method uses the `_FACTORIES` class attribute to create the appropriate
423- factory object for each directive type, and uses these objects to deserialise
424- the parsed directives. The DeserialisedDirectives class is a dataclass
425- containing the deserialised versions of the different directives.
426- """
427- logger .info ("Deserialising directives ..." )
428- deserialised = dict ()
429-
430- for key , func in self ._FACTORIES .items ():
431- ser = func (directives )
432- deserialised [key ] = ser
433-
434- return DeserialisedDirectives (** deserialised )
383+ _INTERFACE_TYPE = IntegrationCodeInterface
0 commit comments