@@ -121,8 +121,6 @@ class NodeModelRegistry:
121121 """Registry for AiiDA REST API node models.
122122
123123 This class maintains mappings of node types and their corresponding Pydantic models.
124-
125- :ivar ModelUnion: A union type of all AiiDA node Pydantic models, discriminated by the `node_type` field.
126124 """
127125
128126 def __init__ (self ) -> None :
@@ -136,22 +134,30 @@ def get_node_types(self) -> list[str]:
136134 """Get the list of registered node class names.
137135
138136 :return: List of node class names.
137+ :rtype: list[str]
139138 """
140139 return list (self ._models .keys ())
141140
142141 def get_node_class_name (self , node_type : str ) -> str :
143142 """Get the AiiDA node class name for a given node type.
144143
145144 :param node_type: The AiiDA node type string.
145+ :type node_type: str
146146 :return: The corresponding node class name.
147+ :rtype: str
147148 """
148149 return node_type .rsplit ('.' , 2 )[- 2 ]
149150
150151 def get_model (self , node_type : str , which : t .Literal ['get' , 'post' ] = 'get' ) -> type [Node .Model ]:
151152 """Get the Pydantic model class for a given node type.
152153
153154 :param node_type: The AiiDA node type string.
155+ :type node_type: str
156+ :param which: Specify whether to get the 'get' or 'post' model.
154157 :return: The corresponding Pydantic model class.
158+ :rtype: type[Node.Model]
159+ :raises MissingEntryPointError: If the node type is not registered.
160+ :raises KeyError: If the specified model type is unknown.
155161 """
156162 if (Model := self ._models .get (node_type )) is None :
157163 raise MissingEntryPointError (f'Unknown node type: { node_type } ' )
@@ -163,7 +169,9 @@ def _get_node_post_model(self, node_cls: Node) -> type[Node.Model]:
163169 """Return a patched Model for the given node class with a literal `node_type` field.
164170
165171 :param node_cls: The AiiDA node class.
172+ :type node_cls: Node
166173 :return: The patched ORM Node model.
174+ :rtype: type[Node.Model]
167175 """
168176 Model = node_cls .CreateModel
169177 node_type = node_cls .class_node_type
@@ -197,6 +205,7 @@ def _get_post_models(self) -> tuple[type[Node.Model], ...]:
197205 """Get a union type of all node 'post' models.
198206
199207 :return: A union type of all node 'post' models.
208+ :rtype: tuple[type[Node.Model], ...]
200209 """
201210 post_models = [model_dict ['post' ] for model_dict in self ._models .values ()]
202211 return tuple (post_models )
0 commit comments