Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def build_type_map(
is_abstract=cdef.info.is_abstract,
is_final_class=cdef.info.is_final,
)
class_ir.is_ext_class = is_extension_class(cdef)
class_ir.is_ext_class = is_extension_class(module.path, cdef, errors)
if class_ir.is_ext_class:
class_ir.deletable = cdef.info.deletable_attributes.copy()
# If global optimizations are disabled, turn of tracking of class children
Expand Down
70 changes: 62 additions & 8 deletions mypyc/irbuild/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from mypy.semanal import refers_to_fullname
from mypy.types import FINAL_DECORATOR_NAMES
from mypyc.errors import Errors

DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"}

Expand Down Expand Up @@ -125,15 +126,68 @@ def get_mypyc_attrs(stmt: ClassDef | Decorator) -> dict[str, Any]:
return attrs


def is_extension_class(cdef: ClassDef) -> bool:
if any(
not is_trait_decorator(d)
and not is_dataclass_decorator(d)
and not get_mypyc_attr_call(d)
and not is_final_decorator(d)
for d in cdef.decorators
):
def is_extension_class(path: str, cdef: ClassDef, errors: Errors) -> bool:
# Check for @mypyc_attr(native_class=True/False) decorator.
explicit_native_class = get_explicit_native_class(path, cdef, errors)

# Classes with native_class=False are explicitly marked as non extension.
if explicit_native_class is False:
return False

implicit_extension_class = is_implicit_extension_class(cdef)

# Classes with native_class=True should be extension classes, but they might
# not be able to be due to other reasons. Print an error in that case.
if explicit_native_class is True and not implicit_extension_class:
errors.error(
"Class is marked as native_class=True but it can't be a native class", path, cdef.line
)

return implicit_extension_class


def get_explicit_native_class(path: str, cdef: ClassDef, errors: Errors) -> bool | None:
"""Return value of @mypyc_attr(native_class=True/False) decorator.

Look for a @mypyc_attr decorator with native_class=True/False and return
the value assigned or None if it doesn't exist. Other values are an error.
"""

for d in cdef.decorators:
mypyc_attr_call = get_mypyc_attr_call(d)
if not mypyc_attr_call:
continue

for i, name in enumerate(mypyc_attr_call.arg_names):
if name != "native_class":
continue

arg = mypyc_attr_call.args[i]
if not isinstance(arg, NameExpr):
errors.error("native_class must be used with True or False only", path, cdef.line)
return None

if arg.name == "False":
return False
elif arg.name == "True":
return True
else:
errors.error("native_class must be used with True or False only", path, cdef.line)
return None
return None


def is_implicit_extension_class(cdef: ClassDef) -> bool:
for d in cdef.decorators:
# Classes that have any decorator other than supported decorators, are not extension classes
if (
not is_trait_decorator(d)
and not is_dataclass_decorator(d)
and not get_mypyc_attr_call(d)
and not is_final_decorator(d)
):
return False

if cdef.info.typeddict_type:
return False
if cdef.info.is_named_tuple:
Expand Down
2 changes: 2 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,10 @@ def next(i: Iterator[_T]) -> _T: pass
def next(i: Iterator[_T], default: _T) -> _T: pass
def hash(o: object) -> int: ...
def globals() -> Dict[str, Any]: ...
def hasattr(obj: object, name: str) -> bool: ...
def getattr(obj: object, name: str, default: Any = None) -> Any: ...
def setattr(obj: object, name: str, value: Any) -> None: ...
def delattr(obj: object, name: str) -> None: ...
def enumerate(x: Iterable[_T]) -> Iterator[Tuple[int, _T]]: ...
@overload
def zip(x: Iterable[_T], y: Iterable[_S]) -> Iterator[Tuple[_T, _S]]: ...
Expand Down
25 changes: 25 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,28 @@ class SomeEnum(Enum):

ALIAS = Literal[SomeEnum.AVALUE]
ALIAS2 = Union[Literal[SomeEnum.AVALUE], None]

[case testMypycAttrNativeClassErrors]
from mypy_extensions import mypyc_attr

@mypyc_attr(native_class=False)
class AnnontatedNonExtensionClass:
pass

@mypyc_attr(native_class=False)
class DerivedExplicitNonNativeClass(AnnontatedNonExtensionClass):
pass


def decorator(cls):
return cls

@mypyc_attr(native_class=True)
@decorator
class NonNativeClassContradiction(): # E: Class is marked as native_class=True but it can't be a native class
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test also a successful case of using native_class=True (in a run test so that the runtime behavior can be tested).

pass


@mypyc_attr(native_class="yes")
class BadUse(): # E: native_class must be used with True or False only
pass
45 changes: 45 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2829,3 +2829,48 @@ Traceback (most recent call last):
File "native.py", line 5, in __del__
raise Exception("e2")
Exception: e2

[case testNonExtensionClassAttr]
from mypy_extensions import mypyc_attr
from testutil import assertRaises

@mypyc_attr(native_class=False)
class AnnontatedNonExtensionClass:
pass

class DerivedClass(AnnontatedNonExtensionClass):
pass

class ImplicitExtensionClass():
pass

def test_function():
setattr(AnnontatedNonExtensionClass, 'attr_class', 5)
assert(hasattr(AnnontatedNonExtensionClass, 'attr_class') == True)
assert(getattr(AnnontatedNonExtensionClass, 'attr_class') == 5)
delattr(AnnontatedNonExtensionClass, 'attr_class')
assert(hasattr(AnnontatedNonExtensionClass, 'attr_class') == False)

inst = AnnontatedNonExtensionClass()
setattr(inst, 'attr_instance', 6)
assert(hasattr(inst, 'attr_instance') == True)
assert(getattr(inst, 'attr_instance') == 6)
delattr(inst, 'attr_instance')
assert(hasattr(inst, 'attr_instance') == False)

setattr(DerivedClass, 'attr_class', 5)
assert(hasattr(DerivedClass, 'attr_class') == True)
assert(getattr(DerivedClass, 'attr_class') == 5)
delattr(DerivedClass, 'attr_class')
assert(hasattr(DerivedClass, 'attr_class') == False)

derived_inst = DerivedClass()
setattr(derived_inst, 'attr_instance', 6)
assert(hasattr(derived_inst, 'attr_instance') == True)
assert(getattr(derived_inst, 'attr_instance') == 6)
delattr(derived_inst, 'attr_instance')
assert(hasattr(derived_inst, 'attr_instance') == False)

ext_inst = ImplicitExtensionClass()
with assertRaises(AttributeError):
setattr(ext_inst, 'attr_instance', 6)