|
1 | 1 | import logging |
2 | 2 | import hashlib |
3 | 3 | from typing import List, Optional |
| 4 | +from enum import Enum |
4 | 5 |
|
5 | 6 | from sqlalchemy import event, UniqueConstraint |
6 | 7 | from sqlalchemy.orm.attributes import get_history |
7 | | - |
8 | 8 | from sqlmodel import ( |
9 | 9 | Field, |
10 | 10 | Relationship, |
11 | 11 | SQLModel, |
12 | 12 | create_engine, |
13 | 13 | ) |
| 14 | +from policyengine_us.system import system |
| 15 | + |
| 16 | +from policyengine_us_data.storage import STORAGE_FOLDER |
| 17 | + |
14 | 18 |
|
15 | 19 | logging.basicConfig( |
16 | 20 | level=logging.INFO, |
|
20 | 24 | logger = logging.getLogger(__name__) |
21 | 25 |
|
22 | 26 |
|
| 27 | +# An Enum type to ensure the variable exists in policyengine-us |
| 28 | +USVariable = Enum( |
| 29 | + "USVariable", {name: name for name in system.variables.keys()}, type=str |
| 30 | +) |
| 31 | + |
| 32 | + |
23 | 33 | class Stratum(SQLModel, table=True): |
24 | 34 | """Represents a unique population subgroup (stratum).""" |
25 | 35 |
|
@@ -79,7 +89,7 @@ class StratumConstraint(SQLModel, table=True): |
79 | 89 | __tablename__ = "stratum_constraints" |
80 | 90 |
|
81 | 91 | stratum_id: int = Field(foreign_key="strata.stratum_id", primary_key=True) |
82 | | - constraint_variable: str = Field( |
| 92 | + constraint_variable: USVariable = Field( |
83 | 93 | primary_key=True, |
84 | 94 | description="The variable the constraint applies to (e.g., 'age').", |
85 | 95 | ) |
@@ -112,7 +122,7 @@ class Target(SQLModel, table=True): |
112 | 122 | ) |
113 | 123 |
|
114 | 124 | target_id: Optional[int] = Field(default=None, primary_key=True) |
115 | | - variable: str = Field( |
| 125 | + variable: USVariable = Field( |
116 | 126 | description="A variable defined in policyengine-us (e.g., 'income_tax')." |
117 | 127 | ) |
118 | 128 | period: int = Field( |
@@ -171,12 +181,11 @@ def calculate_definition_hash(mapper, connection, target: Stratum): |
171 | 181 | fingerprint_text = "\n".join(constraint_strings) |
172 | 182 | h = hashlib.sha256(fingerprint_text.encode("utf-8")) |
173 | 183 | target.definition_hash = h.hexdigest() |
174 | | - logger.info( |
175 | | - f"Set definition_hash for Stratum to '{target.definition_hash}'" |
176 | | - ) |
177 | 184 |
|
178 | 185 |
|
179 | | -def create_database(db_uri="sqlite:///policy_data.db"): |
| 186 | +def create_database( |
| 187 | + db_uri: str = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}", |
| 188 | +): |
180 | 189 | """ |
181 | 190 | Creates a SQLite database and all the defined tables. |
182 | 191 |
|
|
0 commit comments