|
3 | 3 | import contextlib |
4 | 4 | import datetime |
5 | 5 | import re |
6 | | -from collections.abc import Iterator |
| 6 | +from collections.abc import Iterator, Mapping |
7 | 7 | from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, cast, runtime_checkable |
8 | 8 | from uuid import UUID |
9 | 9 |
|
@@ -216,64 +216,42 @@ def __init_subclass__(cls, **kwargs: Any) -> None: |
216 | 216 | they share the parent's table. This hook enforces that rule. |
217 | 217 |
|
218 | 218 | The detection logic identifies STI children by checking: |
219 | | - 1. Class has ``polymorphic_identity`` in ``__mapper_args__`` (explicit STI child marker) |
| 219 | + 1. Class doesn't explicitly define ``__tablename__`` in its own ``__dict__`` |
220 | 220 | 2. AND doesn't have ``concrete=True`` (which would make it CTI) |
221 | | - 3. AND doesn't have ``polymorphic_on`` itself (which would make it a base) |
222 | | - 4. AND doesn't explicitly define ``__tablename__`` in its own ``__dict__`` |
| 221 | + 3. AND doesn't define ``polymorphic_on`` in its own ``__mapper_args__`` (which would make it a base) |
| 222 | + 4. AND inherits from a parent that defines ``polymorphic_on`` in ``__mapper_args__`` (STI hierarchy) |
223 | 223 |
|
224 | | - For children without ``polymorphic_identity`` but with a parent that has |
225 | | - ``polymorphic_on``, SQLAlchemy treats them as abstract intermediate classes |
226 | | - and will issue a warning. We don't modify ``__tablename__`` for these cases. |
| 224 | + For intermediate classes without ``polymorphic_identity`` but with a parent that has |
| 225 | + ``polymorphic_on``, SQLAlchemy can emit a warning. When an intermediate class should |
| 226 | + not be instantiated, set ``polymorphic_abstract=True`` in ``__mapper_args__`` or mark it |
| 227 | + with ``__abstract__ = True``. |
227 | 228 |
|
228 | 229 | This allows both usage patterns: |
229 | 230 | 1. Auto-generated names (don't set ``__tablename__`` on parent) |
230 | 231 | 2. Explicit names (set ``__tablename__`` on parent, STI still works) |
231 | 232 | """ |
232 | | - # IMPORTANT: Modify the class BEFORE calling super().__init_subclass__() |
233 | | - # because super() triggers SQLAlchemy's declarative processing |
234 | | - mapper_args = getattr(cls, "__mapper_args__", {}) |
235 | | - |
236 | | - # Skip if this class explicitly defines its own __tablename__ |
237 | 233 | if "__tablename__" in cls.__dict__: |
238 | 234 | super().__init_subclass__(**kwargs) |
239 | 235 | return |
240 | 236 |
|
241 | | - # Skip if this is CTI (concrete table inheritance) |
242 | | - if mapper_args.get("concrete", False): |
| 237 | + cls_dict = cast("Mapping[str, Any]", cls.__dict__) |
| 238 | + own_mapper_args = cls_dict.get("__mapper_args__") |
| 239 | + own_mapper_args_dict = cast("dict[str, Any]", own_mapper_args) if isinstance(own_mapper_args, dict) else {} |
| 240 | + |
| 241 | + if own_mapper_args_dict.get("concrete", False): |
| 242 | + super().__init_subclass__(**kwargs) |
| 243 | + return |
| 244 | + |
| 245 | + if "polymorphic_on" in own_mapper_args_dict: |
243 | 246 | super().__init_subclass__(**kwargs) |
244 | 247 | return |
245 | 248 |
|
246 | | - # Check if this class might be an STI child |
247 | | - # An STI child either has polymorphic_identity in its own __mapper_args__, |
248 | | - # or inherits from a parent with polymorphic_on |
249 | | - is_potential_sti_child = False |
250 | | - |
251 | | - # Check if THIS class (not inherited) defines polymorphic_on |
252 | | - # If it does, it's a base class, not a child |
253 | | - if "__mapper_args__" in cls.__dict__: |
254 | | - own_mapper_args = cls.__dict__["__mapper_args__"] |
255 | | - if "polymorphic_on" in own_mapper_args: |
256 | | - # This is a base class, not a child - skip |
257 | | - super().__init_subclass__(**kwargs) |
258 | | - return |
259 | | - |
260 | | - # Check if any parent has polymorphic_on (indicates we're in an STI hierarchy) |
261 | 249 | for parent in cls.__mro__[1:]: |
262 | | - if not hasattr(parent, "__mapper_args__"): |
263 | | - continue |
264 | | - parent_mapper_args = getattr(parent, "__mapper_args__", {}) |
265 | | - if "polymorphic_on" in parent_mapper_args: |
266 | | - # We're inheriting from a polymorphic base, so we're an STI child |
267 | | - is_potential_sti_child = True |
| 250 | + parent_mapper_args = getattr(parent, "__mapper_args__", None) |
| 251 | + if isinstance(parent_mapper_args, dict) and "polymorphic_on" in parent_mapper_args: |
| 252 | + cls.__tablename__ = None # type: ignore[misc] |
268 | 253 | break |
269 | 254 |
|
270 | | - if is_potential_sti_child and "__tablename__" not in cls.__dict__: |
271 | | - # For STI children that inherited an explicit __tablename__ from a parent, |
272 | | - # we need to explicitly set it to None so SQLAlchemy knows to use the parent's table. |
273 | | - # This overrides the inherited string value. |
274 | | - cls.__tablename__ = None # type: ignore[misc] |
275 | | - |
276 | | - # Now call super() which triggers SQLAlchemy's declarative system |
277 | 255 | super().__init_subclass__(**kwargs) |
278 | 256 |
|
279 | 257 | if TYPE_CHECKING: |
@@ -359,40 +337,20 @@ class Manager(Employee): |
359 | 337 | __tablename__ = "manager" # Independent table |
360 | 338 | __mapper_args__ = {"concrete": True} |
361 | 339 | """ |
362 | | - # Check if class explicitly defines __tablename__ in its own __dict__ |
363 | | - if "__tablename__" in cls.__dict__: |
364 | | - value = cls.__dict__["__tablename__"] |
365 | | - # If explicitly set to None (e.g., by __init_subclass__ for STI), return None |
366 | | - if value is None: |
367 | | - return None |
368 | | - return value |
| 340 | + cls_dict = cast("Mapping[str, Any]", cls.__dict__) |
| 341 | + if "__tablename__" in cls_dict: |
| 342 | + return cast("Optional[str]", cls_dict["__tablename__"]) |
369 | 343 |
|
370 | | - # Check if this is an STI child class that needs auto-detection |
371 | | - # This handles cases where the parent didn't explicitly set __tablename__ |
372 | 344 | mapper_args = getattr(cls, "__mapper_args__", {}) |
| 345 | + mapper_args_dict = cast("dict[str, Any]", mapper_args) if isinstance(mapper_args, dict) else {} |
| 346 | + if mapper_args_dict.get("concrete", False) or "polymorphic_on" in mapper_args_dict: |
| 347 | + return table_name_regexp.sub(r"_\1", cls.__name__).lower() |
373 | 348 |
|
374 | | - # Skip STI detection if this class defines polymorphic_on (it's a base, not a child) |
375 | | - if "polymorphic_on" not in mapper_args: |
376 | | - is_sti_child = False |
377 | | - |
378 | | - # Check explicit STI marker |
379 | | - if "polymorphic_identity" in mapper_args: |
380 | | - is_sti_child = True |
381 | | - else: |
382 | | - # Check if any parent has polymorphic_on (indicates STI hierarchy) |
383 | | - for parent in cls.__mro__[1:]: |
384 | | - if not hasattr(parent, "__mapper_args__"): |
385 | | - continue |
386 | | - parent_mapper_args = getattr(parent, "__mapper_args__", {}) |
387 | | - if "polymorphic_on" in parent_mapper_args: |
388 | | - is_sti_child = True |
389 | | - break |
390 | | - |
391 | | - if is_sti_child: |
392 | | - # This is an STI child - return None to use parent's table |
| 349 | + for parent in cls.__mro__[1:]: |
| 350 | + parent_mapper_args = getattr(parent, "__mapper_args__", None) |
| 351 | + if isinstance(parent_mapper_args, dict) and "polymorphic_on" in parent_mapper_args: |
393 | 352 | return None |
394 | 353 |
|
395 | | - # Generate table name from class name using snake_case conversion |
396 | 354 | return table_name_regexp.sub(r"_\1", cls.__name__).lower() |
397 | 355 |
|
398 | 356 |
|
|
0 commit comments