diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index f452e6f18..c47048218 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -1918,6 +1918,30 @@ } ] }, + "AllOfRequiredBase": { + "type": "object", + "properties": { + "bar": { + "type": "string", + "description": "The bar property" + }, + "baz": { + "type": "string", + "description": "The baz property" + } + } + }, + "AllOfRequiredDerived": { + "allOf": [ + { + "$ref": "#/components/schemas/AllOfRequiredBase" + }, + { + "type": "object", + "required": ["bar"] + } + ] + }, "AModel": { "title": "AModel", "required": [ diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index 911ca8842..295e6818a 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -1871,6 +1871,27 @@ info: } ] }, + "AllOfRequiredBase": { + "type": "object", + "properties": { + "bar": { + "type": "string", + "description": "The bar property" + }, + "baz": { + "type": "string", + "description": "The baz property" + } + } + }, + "AllOfRequiredDerived": { + "allOf": [ + { "$ref": "#/components/schemas/AllOfRequiredBase" }, + { "type": "object", + "required": ["bar"] + } + ] + }, "AModel": { "title": "AModel", "required": [ diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py b/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py index cd897d9fe..c62e4cfa6 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py @@ -7,6 +7,8 @@ from .a_model_with_properties_reference_that_are_not_object import AModelWithPropertiesReferenceThatAreNotObject from .all_of_has_properties_but_no_type import AllOfHasPropertiesButNoType from .all_of_has_properties_but_no_type_type_enum import AllOfHasPropertiesButNoTypeTypeEnum +from .all_of_required_base import AllOfRequiredBase +from .all_of_required_derived import AllOfRequiredDerived from .all_of_sub_model import AllOfSubModel from .all_of_sub_model_type_enum import AllOfSubModelTypeEnum from .an_all_of_enum import AnAllOfEnum @@ -100,6 +102,8 @@ "AFormData", "AllOfHasPropertiesButNoType", "AllOfHasPropertiesButNoTypeTypeEnum", + "AllOfRequiredBase", + "AllOfRequiredDerived", "AllOfSubModel", "AllOfSubModelTypeEnum", "AModel", diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_base.py b/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_base.py new file mode 100644 index 000000000..b06f45c30 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_base.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AllOfRequiredBase") + + +@_attrs_define +class AllOfRequiredBase: + """ + Attributes: + bar (str | Unset): The bar property + baz (str | Unset): The baz property + """ + + bar: str | Unset = UNSET + baz: str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + bar = self.bar + + baz = self.baz + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update({}) + if bar is not UNSET: + field_dict["bar"] = bar + if baz is not UNSET: + field_dict["baz"] = baz + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + bar = d.pop("bar", UNSET) + + baz = d.pop("baz", UNSET) + + all_of_required_base = cls( + bar=bar, + baz=baz, + ) + + all_of_required_base.additional_properties = d + return all_of_required_base + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_derived.py b/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_derived.py new file mode 100644 index 000000000..505334d6b --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/models/all_of_required_derived.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AllOfRequiredDerived") + + +@_attrs_define +class AllOfRequiredDerived: + """ + Attributes: + bar (str): The bar property + baz (str | Unset): The baz property + """ + + bar: str + baz: str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + bar = self.bar + + baz = self.baz + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "bar": bar, + } + ) + if baz is not UNSET: + field_dict["baz"] = baz + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + bar = d.pop("bar") + + baz = d.pop("baz", UNSET) + + all_of_required_derived = cls( + bar=bar, + baz=baz, + ) + + all_of_required_derived.additional_properties = d + return all_of_required_derived + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index c304d965e..636b71a34 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -292,6 +292,11 @@ def _add_if_no_conflict(new_prop: Property) -> PropertyError | None: unprocessed_props.extend(sub_prop.properties.items() if sub_prop.properties else []) required_set.update(sub_prop.required or []) + # Update properties that are marked as required in the schema + for prop_name in required_set: + if prop_name in properties and not properties[prop_name].required: + properties[prop_name] = evolve(properties[prop_name], required=True) + for key, value in unprocessed_props: prop_required = key in required_set prop_or_error: Property | (PropertyError | None) diff --git a/tests/test_parser/test_properties/test_model_property.py b/tests/test_parser/test_properties/test_model_property.py index 4445d441a..f84a31a17 100644 --- a/tests/test_parser/test_properties/test_model_property.py +++ b/tests/test_parser/test_properties/test_model_property.py @@ -543,6 +543,42 @@ def test_duplicate_properties(self, model_property_factory, string_property_fact assert result.optional_props == [prop], "There should only be one copy of duplicate properties" + def test_allof_required_override(self, model_property_factory, string_property_factory, config): + """Test that required field can be overridden in allOf schemas""" + # Simulates: + # FooBase: + # type: object + # properties: + # bar: {type: string} + # baz: {type: string} + # FooCreate: + # allOf: + # - $ref: '#/components/schemas/FooBase' + # - type: object + # required: [bar] + bar_prop = string_property_factory(name="bar", required=False) + baz_prop = string_property_factory(name="baz", required=False) + + data = oai.Schema.model_construct( + allOf=[ + oai.Reference.model_construct(ref="#/FooBase"), + oai.Schema.model_construct(type="object", required=["bar"]), + ] + ) + schemas = Schemas( + classes_by_reference={ + "/FooBase": model_property_factory(required_properties=[], optional_properties=[bar_prop, baz_prop]), + } + ) + + result = _process_properties(data=data, schemas=schemas, class_name="FooCreate", config=config, roots={"root"}) + + # bar should now be required, baz should remain optional + assert len(result.required_props) == 1 + assert result.required_props[0].name == "bar" + assert len(result.optional_props) == 1 + assert result.optional_props[0].name == "baz" + @pytest.mark.parametrize("first_required", [True, False]) @pytest.mark.parametrize("second_required", [True, False]) def test_mixed_requirements(