|
13 | 13 | from uuid import UUID |
14 | 14 |
|
15 | 15 | import sqlalchemy |
16 | | -from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text, update |
| 16 | +from sqlalchemy import ( |
| 17 | + ForeignKeyConstraint, |
| 18 | + MetaData, |
| 19 | + Table, |
| 20 | + cast as cast_, |
| 21 | + func, |
| 22 | + text, |
| 23 | + update, |
| 24 | +) |
17 | 25 | from sqlalchemy.engine import CursorResult, Engine |
18 | 26 | from sqlalchemy.exc import ( |
19 | 27 | DatabaseError, |
|
26 | 34 | from sqlalchemy.orm import DeclarativeBase |
27 | 35 | from sqlalchemy.orm.session import Session |
28 | 36 | from sqlalchemy.schema import AddConstraint, CreateTable, DropConstraint |
29 | | -from sqlalchemy.sql.expression import true |
| 37 | +from sqlalchemy.sql.expression import and_, true |
30 | 38 | from sqlalchemy.sql.lambdas import StatementLambdaElement |
| 39 | +from sqlalchemy.types import BINARY |
31 | 40 |
|
32 | 41 | from homeassistant.core import HomeAssistant |
33 | 42 | from homeassistant.util.enum import try_parse_enum |
@@ -2044,14 +2053,74 @@ def _apply_update(self) -> None: |
2044 | 2053 | class _SchemaVersion51Migrator(_SchemaVersionMigrator, target_version=51): |
2045 | 2054 | def _apply_update(self) -> None: |
2046 | 2055 | """Version specific update method.""" |
2047 | | - # Add unit class column to StatisticsMeta |
| 2056 | + # Replaced with version 52 which corrects issues with MySQL string comparisons. |
| 2057 | + |
| 2058 | + |
| 2059 | +class _SchemaVersion52Migrator(_SchemaVersionMigrator, target_version=52): |
| 2060 | + def _apply_update(self) -> None: |
| 2061 | + """Version specific update method.""" |
| 2062 | + if self.engine.dialect.name == SupportedDialect.MYSQL: |
| 2063 | + self._apply_update_mysql() |
| 2064 | + else: |
| 2065 | + self._apply_update_postgresql_sqlite() |
| 2066 | + |
| 2067 | + def _apply_update_mysql(self) -> None: |
| 2068 | + """Version specific update method for mysql.""" |
| 2069 | + _add_columns(self.session_maker, "statistics_meta", ["unit_class VARCHAR(255)"]) |
| 2070 | + with session_scope(session=self.session_maker()) as session: |
| 2071 | + connection = session.connection() |
| 2072 | + for conv in _PRIMARY_UNIT_CONVERTERS: |
| 2073 | + case_sensitive_units = { |
| 2074 | + u.encode("utf-8") if u else u for u in conv.VALID_UNITS |
| 2075 | + } |
| 2076 | + # Reset unit_class to None for entries that do not match |
| 2077 | + # the valid units (case sensitive) but matched before due to |
| 2078 | + # case insensitive comparisons. |
| 2079 | + connection.execute( |
| 2080 | + update(StatisticsMeta) |
| 2081 | + .where( |
| 2082 | + and_( |
| 2083 | + StatisticsMeta.unit_of_measurement.in_(conv.VALID_UNITS), |
| 2084 | + cast_(StatisticsMeta.unit_of_measurement, BINARY).not_in( |
| 2085 | + case_sensitive_units |
| 2086 | + ), |
| 2087 | + ) |
| 2088 | + ) |
| 2089 | + .values(unit_class=None) |
| 2090 | + ) |
| 2091 | + # Do an explicitly case sensitive match (actually binary) to set the |
| 2092 | + # correct unit_class. This is needed because we use the case sensitive |
| 2093 | + # utf8mb4_unicode_ci collation. |
| 2094 | + connection.execute( |
| 2095 | + update(StatisticsMeta) |
| 2096 | + .where( |
| 2097 | + and_( |
| 2098 | + cast_(StatisticsMeta.unit_of_measurement, BINARY).in_( |
| 2099 | + case_sensitive_units |
| 2100 | + ), |
| 2101 | + StatisticsMeta.unit_class.is_(None), |
| 2102 | + ) |
| 2103 | + ) |
| 2104 | + .values(unit_class=conv.UNIT_CLASS) |
| 2105 | + ) |
| 2106 | + |
| 2107 | + def _apply_update_postgresql_sqlite(self) -> None: |
| 2108 | + """Version specific update method for postgresql and sqlite.""" |
2048 | 2109 | _add_columns(self.session_maker, "statistics_meta", ["unit_class VARCHAR(255)"]) |
2049 | 2110 | with session_scope(session=self.session_maker()) as session: |
2050 | 2111 | connection = session.connection() |
2051 | 2112 | for conv in _PRIMARY_UNIT_CONVERTERS: |
| 2113 | + # Set the correct unit_class. Unlike MySQL, Postgres and SQLite |
| 2114 | + # have case sensitive string comparisons by default, so we |
| 2115 | + # can directly match on the valid units. |
2052 | 2116 | connection.execute( |
2053 | 2117 | update(StatisticsMeta) |
2054 | | - .where(StatisticsMeta.unit_of_measurement.in_(conv.VALID_UNITS)) |
| 2118 | + .where( |
| 2119 | + and_( |
| 2120 | + StatisticsMeta.unit_of_measurement.in_(conv.VALID_UNITS), |
| 2121 | + StatisticsMeta.unit_class.is_(None), |
| 2122 | + ) |
| 2123 | + ) |
2055 | 2124 | .values(unit_class=conv.UNIT_CLASS) |
2056 | 2125 | ) |
2057 | 2126 |
|
|
0 commit comments