Skip to content

Commit 3252eb2

Browse files
committed
Support parameter reference resolution in service generator
Enhances the Python service generator to resolve parameter references from OpenAPI components, ensuring referenced parameters are correctly included in generated service methods. Adds internal helper for reference resolution, updates function signatures to accept components, and introduces tests to verify correct parameter resolution for both OpenAPI 3.0 and 3.1.
1 parent f356cf0 commit 3252eb2

File tree

3 files changed

+274
-104
lines changed

3 files changed

+274
-104
lines changed

src/openapi_python_generator/language_converters/python/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def generator(
4141
models = []
4242

4343
if data.paths is not None:
44-
services = generate_services(data.paths, library_config)
44+
services = generate_services(data.paths, library_config, data.components)
4545
else:
4646
services = []
4747

src/openapi_python_generator/language_converters/python/service_generator.py

Lines changed: 85 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
Schema as Schema31,
3838
)
3939
from openapi_pydantic.v3.v3_1.parameter import Parameter as Parameter31
40+
from openapi_pydantic.v3.v3_0 import Components as Components30
41+
from openapi_pydantic.v3.v3_1 import Components as Components31
4042

4143
from openapi_python_generator.language_converters.python import common
4244
from openapi_python_generator.language_converters.python.common import normalize_symbol
@@ -54,6 +56,32 @@
5456
TypeConversion,
5557
)
5658

59+
# Type alias for Components
60+
Components = Union[Components30, Components31]
61+
62+
# Module-level storage for component parameters (set by generate_services)
63+
_component_params: Optional[Dict[str, Union[Parameter30, Parameter31]]] = None
64+
65+
66+
def _resolve_parameter_ref(
67+
param: Union[Parameter30, Parameter31, Reference30, Reference31],
68+
) -> Optional[Union[Parameter30, Parameter31]]:
69+
if isinstance(param, (Parameter30, Parameter31)):
70+
return param
71+
72+
if isinstance(param, (Reference, Reference30, Reference31)):
73+
if _component_params is None:
74+
return None
75+
# Extract parameter name from $ref like "#/components/parameters/LangParameter"
76+
ref_str = getattr(param, "ref", None)
77+
if ref_str and ref_str.startswith("#/components/parameters/"):
78+
param_name = ref_str.split("/")[-1]
79+
resolved = _component_params.get(param_name)
80+
if resolved is not None:
81+
return resolved
82+
83+
return None
84+
5785

5886
# Helper functions for isinstance checks across OpenAPI versions
5987
def is_response_type(obj) -> bool:
@@ -97,9 +125,7 @@ def generate_body_param(operation: Operation) -> Union[str, None]:
97125
if operation.requestBody is None:
98126
return None
99127
else:
100-
if isinstance(operation.requestBody, Reference30) or isinstance(
101-
operation.requestBody, Reference31
102-
):
128+
if isinstance(operation.requestBody, Reference30) or isinstance(operation.requestBody, Reference31):
103129
return "data.dict()"
104130

105131
if operation.requestBody.content is None:
@@ -113,9 +139,7 @@ def generate_body_param(operation: Operation) -> Union[str, None]:
113139
if media_type is None:
114140
return None # pragma: no cover
115141

116-
if isinstance(
117-
media_type.media_type_schema, (Reference, Reference30, Reference31)
118-
):
142+
if isinstance(media_type.media_type_schema, (Reference, Reference30, Reference31)):
119143
return "data.dict()"
120144
elif hasattr(media_type.media_type_schema, "ref"):
121145
# Handle Reference objects from different OpenAPI versions
@@ -127,9 +151,7 @@ def generate_body_param(operation: Operation) -> Union[str, None]:
127151
elif schema.type == "object":
128152
return "data"
129153
else:
130-
raise Exception(
131-
f"Unsupported schema type for request body: {schema.type}"
132-
) # pragma: no cover
154+
raise Exception(f"Unsupported schema type for request body: {schema.type}") # pragma: no cover
133155
else:
134156
raise Exception(
135157
f"Unsupported schema type for request body: {type(media_type.media_type_schema)}"
@@ -153,32 +175,26 @@ def _generate_params_from_content(content: Any):
153175
default_params = ""
154176
if operation.parameters is not None:
155177
for param in operation.parameters:
156-
if not isinstance(param, (Parameter30, Parameter31)):
157-
continue # pragma: no cover
178+
# Resolve parameter references to their actual definitions
179+
resolved_param = _resolve_parameter_ref(param)
180+
if resolved_param is None:
181+
continue # Skip if we can't resolve the reference
182+
param = resolved_param
158183
converted_result = ""
159184
required = False
160185
param_name_cleaned = common.normalize_symbol(param.name)
161186

162-
if isinstance(param.param_schema, Schema30) or isinstance(
163-
param.param_schema, Schema31
164-
):
187+
if isinstance(param.param_schema, Schema30) or isinstance(param.param_schema, Schema31):
165188
converted_result = (
166189
f"{param_name_cleaned} : {type_converter(param.param_schema, param.required).converted_type}"
167190
+ ("" if param.required else " = None")
168191
)
169192
required = param.required
170-
elif isinstance(param.param_schema, Reference30) or isinstance(
171-
param.param_schema, Reference31
172-
):
173-
converted_result = (
174-
f"{param_name_cleaned} : {param.param_schema.ref.split('/')[-1] }"
175-
+ (
176-
""
177-
if isinstance(param, Reference30)
178-
or isinstance(param, Reference31)
179-
or param.required
180-
else " = None"
181-
)
193+
elif isinstance(param.param_schema, Reference30) or isinstance(param.param_schema, Reference31):
194+
converted_result = f"{param_name_cleaned} : {param.param_schema.ref.split('/')[-1]}" + (
195+
""
196+
if isinstance(param, Reference30) or isinstance(param, Reference31) or param.required
197+
else " = None"
182198
)
183199
required = isinstance(param, Reference) or param.required
184200

@@ -194,17 +210,11 @@ def _generate_params_from_content(content: Any):
194210
"application/octet-stream",
195211
]
196212

197-
if operation.requestBody is not None and not is_reference_type(
198-
operation.requestBody
199-
):
213+
if operation.requestBody is not None and not is_reference_type(operation.requestBody):
200214
# Safe access only if it's a concrete RequestBody object
201215
rb_content = getattr(operation.requestBody, "content", None)
202-
if isinstance(rb_content, dict) and any(
203-
rb_content.get(i) is not None for i in operation_request_body_types
204-
):
205-
get_keyword = [
206-
i for i in operation_request_body_types if rb_content.get(i)
207-
][0]
216+
if isinstance(rb_content, dict) and any(rb_content.get(i) is not None for i in operation_request_body_types):
217+
get_keyword = [i for i in operation_request_body_types if rb_content.get(i)][0]
208218
content = rb_content.get(get_keyword)
209219
if content is not None and hasattr(content, "media_type_schema"):
210220
mts = getattr(content, "media_type_schema", None)
@@ -214,9 +224,7 @@ def _generate_params_from_content(content: Any):
214224
):
215225
params += f"{_generate_params_from_content(mts)}, "
216226
else: # pragma: no cover
217-
raise Exception(
218-
f"Unsupported media type schema for {str(operation)}: {type(mts)}"
219-
)
227+
raise Exception(f"Unsupported media type schema for {str(operation)}: {type(mts)}")
220228
# else: silently ignore unsupported body shapes (could extend later)
221229
# Replace - with _ in params
222230
params = params.replace("-", "_")
@@ -225,9 +233,7 @@ def _generate_params_from_content(content: Any):
225233
return params + default_params
226234

227235

228-
def generate_operation_id(
229-
operation: Operation, http_op: str, path_name: Optional[str] = None
230-
) -> str:
236+
def generate_operation_id(operation: Operation, http_op: str, path_name: Optional[str] = None) -> str:
231237
if operation.operationId is not None:
232238
return common.normalize_symbol(operation.operationId)
233239
elif path_name is not None:
@@ -238,17 +244,18 @@ def generate_operation_id(
238244
) # pragma: no cover
239245

240246

241-
def _generate_params(
242-
operation: Operation, param_in: Literal["query", "header"] = "query"
243-
):
247+
def _generate_params(operation: Operation, param_in: Literal["query", "header"] = "query"):
244248
if operation.parameters is None:
245249
return []
246250

247251
params = []
248252
for param in operation.parameters:
249-
if isinstance(param, (Parameter30, Parameter31)) and param.param_in == param_in:
250-
param_name_cleaned = common.normalize_symbol(param.name)
251-
params.append(f"{param.name!r} : {param_name_cleaned}")
253+
# Resolve parameter references to their actual definitions
254+
resolved_param = _resolve_parameter_ref(param)
255+
if resolved_param is None or resolved_param.param_in != param_in:
256+
continue # Skip if we can't resolve the reference
257+
param_name_cleaned = common.normalize_symbol(resolved_param.name)
258+
params.append(f"{resolved_param.name!r} : {param_name_cleaned}")
252259

253260
return params
254261

@@ -284,9 +291,7 @@ def generate_return_type(operation: Operation) -> OpReturnType:
284291
media_type_schema = create_media_type_for_reference(chosen_response)
285292

286293
if media_type_schema is None:
287-
return OpReturnType(
288-
type=None, status_code=good_responses[0][0], complex_type=False
289-
)
294+
return OpReturnType(type=None, status_code=good_responses[0][0], complex_type=False)
290295

291296
if is_media_type(media_type_schema):
292297
inner_schema = getattr(media_type_schema, "media_type_schema", None)
@@ -303,25 +308,18 @@ def generate_return_type(operation: Operation) -> OpReturnType:
303308
)
304309
elif is_schema_type(inner_schema):
305310
converted_result = type_converter(inner_schema, True) # type: ignore
306-
if "array" in converted_result.original_type and isinstance(
307-
converted_result.import_types, list
308-
):
311+
if "array" in converted_result.original_type and isinstance(converted_result.import_types, list):
309312
matched = re.findall(r"List\[(.+)\]", converted_result.converted_type)
310313
if len(matched) > 0:
311314
list_type = matched[0]
312315
else: # pragma: no cover
313-
raise Exception(
314-
f"Unable to parse list type from {converted_result.converted_type}"
315-
)
316+
raise Exception(f"Unable to parse list type from {converted_result.converted_type}")
316317
else:
317318
list_type = None
318319
return OpReturnType(
319320
type=converted_result,
320321
status_code=good_responses[0][0],
321-
complex_type=bool(
322-
converted_result.import_types
323-
and len(converted_result.import_types) > 0
324-
),
322+
complex_type=bool(converted_result.import_types and len(converted_result.import_types) > 0),
325323
list_type=list_type,
326324
)
327325
else: # pragma: no cover
@@ -337,18 +335,31 @@ def generate_return_type(operation: Operation) -> OpReturnType:
337335

338336

339337
def generate_services(
340-
paths: Dict[str, PathItem], library_config: LibraryConfig
338+
paths: Dict[str, PathItem],
339+
library_config: LibraryConfig,
340+
components: Optional[Components] = None,
341341
) -> List[Service]:
342342
"""
343343
Generates services from a paths object.
344344
:param paths: paths object to be converted
345+
:param library_config: Library configuration
346+
:param components: Optional OpenAPI components for resolving parameter references
345347
:return: List of services
346348
"""
349+
global _component_params
350+
351+
# Build a lookup dict for component parameters if available
352+
if components is not None and hasattr(components, "parameters") and components.parameters is not None:
353+
_component_params = {}
354+
for param_name, param_or_ref in components.parameters.items():
355+
if isinstance(param_or_ref, (Parameter30, Parameter31)):
356+
_component_params[param_name] = param_or_ref
357+
else:
358+
_component_params = None
359+
347360
jinja_env = create_jinja_env()
348361

349-
def generate_service_operation(
350-
op: Operation, path_name: str, async_type: bool
351-
) -> ServiceOperation:
362+
def generate_service_operation(op: Operation, path_name: str, async_type: bool) -> ServiceOperation:
352363
# Merge path-level parameters (always required by spec) into the
353364
# operation-level parameters so they get turned into function args.
354365
try:
@@ -362,36 +373,25 @@ def generate_service_operation(
362373
if isinstance(p, (Parameter30, Parameter31)):
363374
existing_names.add(p.name)
364375
for p in path_level_params:
365-
if (
366-
isinstance(p, (Parameter30, Parameter31))
367-
and p.name not in existing_names
368-
):
376+
if isinstance(p, (Parameter30, Parameter31)) and p.name not in existing_names:
369377
if op.parameters is None:
370378
op.parameters = [] # type: ignore
371379
op.parameters.append(p) # type: ignore
372380
except Exception: # pragma: no cover
373-
print(
374-
f"Error merging path-level parameters for {path_name}"
375-
) # pragma: no cover
381+
print(f"Error merging path-level parameters for {path_name}") # pragma: no cover
376382
pass
377383

378384
params = generate_params(op)
379385
# Fallback: ensure all {placeholders} in path are present as function params
380386
try:
381-
placeholder_names = [
382-
m.group(1) for m in re.finditer(r"\{([^}/]+)\}", path_name)
383-
]
384-
existing_param_names = {
385-
p.split(":")[0].strip() for p in params.split(",") if ":" in p
386-
}
387+
placeholder_names = [m.group(1) for m in re.finditer(r"\{([^}/]+)\}", path_name)]
388+
existing_param_names = {p.split(":")[0].strip() for p in params.split(",") if ":" in p}
387389
for ph in placeholder_names:
388390
norm_ph = common.normalize_symbol(ph)
389391
if norm_ph not in existing_param_names and norm_ph:
390392
params = f"{norm_ph}: Any, " + params
391393
except Exception: # pragma: no cover
392-
print(
393-
f"Error ensuring path placeholders in params for {path_name}"
394-
) # pragma: no cover
394+
print(f"Error ensuring path placeholders in params for {path_name}") # pragma: no cover
395395
pass
396396
operation_id = generate_operation_id(op, http_operation, path_name)
397397
query_params = generate_query_params(op)
@@ -415,9 +415,7 @@ def generate_service_operation(
415415
use_orjson=common.get_use_orjson(),
416416
)
417417

418-
so.content = jinja_env.get_template(library_config.template_name).render(
419-
**so.model_dump()
420-
)
418+
so.content = jinja_env.get_template(library_config.template_name).render(**so.model_dump())
421419

422420
if op.tags is not None and len(op.tags) > 0:
423421
so.tag = normalize_symbol(op.tags[0])
@@ -457,16 +455,8 @@ def generate_service_operation(
457455
services.append(
458456
Service(
459457
file_name=f"{tag}_service",
460-
operations=[
461-
so for so in service_ops if so.tag == tag and not so.async_client
462-
],
463-
content="\n".join(
464-
[
465-
so.content
466-
for so in service_ops
467-
if so.tag == tag and not so.async_client
468-
]
469-
),
458+
operations=[so for so in service_ops if so.tag == tag and not so.async_client],
459+
content="\n".join([so.content for so in service_ops if so.tag == tag and not so.async_client]),
470460
async_client=False,
471461
library_import=library_config.library_name,
472462
use_orjson=common.get_use_orjson(),
@@ -477,16 +467,8 @@ def generate_service_operation(
477467
services.append(
478468
Service(
479469
file_name=f"async_{tag}_service",
480-
operations=[
481-
so for so in service_ops if so.tag == tag and so.async_client
482-
],
483-
content="\n".join(
484-
[
485-
so.content
486-
for so in service_ops
487-
if so.tag == tag and so.async_client
488-
]
489-
),
470+
operations=[so for so in service_ops if so.tag == tag and so.async_client],
471+
content="\n".join([so.content for so in service_ops if so.tag == tag and so.async_client]),
490472
async_client=True,
491473
library_import=library_config.library_name,
492474
use_orjson=common.get_use_orjson(),

0 commit comments

Comments
 (0)