Skip to content

Commit ecfe652

Browse files
Google DeepMindcopybara-github
authored andcommitted
Add nullable attribute to AST nodes for function parameters in API introspection. Fixes #309
This change introduces an `nullable` flag to `ArrayType`, `PointerType`, and `FunctionParameterDecl`. The code generation process now parses function comments from `mujoco.h` to identify parameters marked as "Nullable" and sets this flag accordingly. PiperOrigin-RevId: 797870883 Change-Id: Ic7cac5a7a9c14177fbf97b8655c5c4809d4ef683
1 parent 3dec35a commit ecfe652

File tree

5 files changed

+146
-8
lines changed

5 files changed

+146
-8
lines changed

python/mujoco/introspect/ast_nodes.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ class ValueType:
6262
name: str
6363
is_const: bool = False
6464
is_volatile: bool = False
65+
nullable: bool = False
6566

6667
def __init__(self, name: str, is_const: bool = False,
67-
is_volatile: bool = False):
68+
is_volatile: bool = False,
69+
nullable: bool = False):
6870
is_valid_type_name = (
6971
name == 'void *(*)(void *)' or
7072
VALID_TYPE_NAME_PATTERN.fullmatch(name) or
@@ -74,6 +76,7 @@ def __init__(self, name: str, is_const: bool = False,
7476
self.name = name
7577
self.is_const = is_const
7678
self.is_volatile = is_volatile
79+
self.nullable = nullable
7780

7881
def decl(self, name_or_decl: Optional[str] = None) -> str:
7982
parts = []
@@ -96,9 +99,11 @@ class ArrayType:
9699

97100
inner_type: Union[ValueType, 'PointerType']
98101
extents: Tuple[int, ...]
102+
nullable: bool = False
99103

100-
def __init__(self, inner_type: Union[ValueType, 'PointerType'],
101-
extents: Sequence[int]):
104+
def __init__(
105+
self, inner_type: Union[ValueType, 'PointerType'], extents: Sequence[int]
106+
):
102107
self.inner_type = inner_type
103108
self.extents = tuple(extents)
104109

@@ -119,13 +124,16 @@ class PointerType:
119124
"""Represents a C pointer type."""
120125

121126
inner_type: Union[ValueType, ArrayType, 'PointerType']
127+
nullable: bool = False
122128
is_const: bool = False
123129
is_volatile: bool = False
124130
is_restrict: bool = False
125131

126132
def decl(self, name_or_decl: Optional[str] = None) -> str:
127133
"""Creates a string that declares an object of this type."""
128134
parts = ['*']
135+
if self.nullable:
136+
parts.append('nullable')
129137
if self.is_const:
130138
parts.append('const')
131139
if self.is_volatile:
@@ -155,6 +163,7 @@ class FunctionParameterDecl:
155163

156164
name: str
157165
type: Union[ValueType, ArrayType, PointerType]
166+
nullable: bool = False
158167

159168
def __str__(self):
160169
return self.type.decl(self.name)

python/mujoco/introspect/ast_nodes_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,29 @@ def test_anonymous_union_decl(self):
174174
self.assertEqual(str(union), 'union {int foo; float bar[3];}')
175175
self.assertEqual(union.decl('var'), 'union {int foo; float bar[3];} var')
176176

177+
def test_function_parameter_decl_nullable(self):
178+
param_ptr_nullable = ast_nodes.FunctionParameterDecl(
179+
name='ptr_param',
180+
type=ast_nodes.PointerType(ast_nodes.ValueType('int')),
181+
nullable=True
182+
)
183+
self.assertTrue(param_ptr_nullable.nullable)
184+
self.assertEqual(str(param_ptr_nullable), 'int * ptr_param')
185+
186+
param_ptr_not_nullable = ast_nodes.FunctionParameterDecl(
187+
name='ptr_param',
188+
type=ast_nodes.PointerType(ast_nodes.ValueType('int')),
189+
)
190+
self.assertFalse(param_ptr_not_nullable.nullable)
191+
self.assertEqual(str(param_ptr_not_nullable), 'int * ptr_param')
192+
193+
param_array_nullable = ast_nodes.FunctionParameterDecl(
194+
name='array_param',
195+
type=ast_nodes.ArrayType(ast_nodes.ValueType('float'), (10,)),
196+
nullable=True
197+
)
198+
self.assertTrue(param_array_nullable.nullable)
199+
self.assertEqual(str(param_array_nullable), 'float array_param[10]')
177200

178201
if __name__ == '__main__':
179202
absltest.main()

python/mujoco/introspect/codegen/generate_functions.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,43 @@ def _make_function(self, node: ClangJsonNode) -> ast_nodes.FunctionDecl:
5858
node['type']['qualType'])
5959
parameters = []
6060
comments = []
61+
nullable_params = set()
62+
6163
for child in node['inner']:
6264
child_kind = child.get('kind')
63-
if child_kind == 'ParmVarDecl':
64-
parameters.append(self._make_parameter(child))
6565
if child_kind == 'FullComment':
6666
comments.append(self._make_comment(child))
67+
nullable_params.update(self._find_nullable_params(child))
6768
comment = ' '.join(comments).strip()
69+
70+
for child in node['inner']:
71+
child_kind = child.get('kind')
72+
if child_kind == 'ParmVarDecl':
73+
parameters.append(self._make_parameter(child, nullable_params))
74+
6875
return ast_nodes.FunctionDecl(
6976
name=name, return_type=return_type, parameters=parameters, doc=comment)
7077

78+
def _find_nullable_params(self, node: ClangJsonNode) -> set[str]:
79+
"""Finds the names of parameters that are marked as nullable."""
80+
nullable_params = set()
81+
for child in node['inner']:
82+
child_kind = child.get('kind')
83+
if child_kind == 'ParagraphComment':
84+
nullable_params.update(self._find_nullable_params(child))
85+
if child_kind == 'TextComment':
86+
if 'Nullable' in child['text']:
87+
for param in child['text'].split(':')[1].split(','):
88+
nullable_params.add(param.strip())
89+
return nullable_params
90+
7191
def _make_parameter(
72-
self, node: ClangJsonNode) -> ast_nodes.FunctionParameterDecl:
92+
self, node: ClangJsonNode, nullable_params: set[str]
93+
) -> ast_nodes.FunctionParameterDecl:
7394
"""Makes a ParameterDecl from a Clang AST ParmVarDecl node."""
7495
name = node['name']
7596
type_name = node['type']['qualType']
97+
nullable = name in nullable_params
7698

7799
# For a pointer parameters, look up in the original header to see if
78100
# n array extent was declared there.
@@ -85,7 +107,10 @@ def _make_parameter(
85107
type_name = decl[:name_begin] + decl[name_end:]
86108

87109
return ast_nodes.FunctionParameterDecl(
88-
name=name, type=type_parsing.parse_type(type_name))
110+
nullable=nullable,
111+
name=name,
112+
type=type_parsing.parse_type(type_name),
113+
)
89114

90115
def _make_comment(self, node: ClangJsonNode) -> str:
91116
"""Makes a comment string from a Clang AST FullComment node."""

0 commit comments

Comments
 (0)