Commit 6fe0104

Define db schema using SQLAlchemy
1 parent: 6533418

3 files changed: 97 additions & 29 deletions


pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ dependencies = [
     "pypubsub",
     "tomlkit",
     "duckdb",
-    "fsspec"
+    "fsspec",
+    "sqlalchemy",
+    "duckdb-engine"
 ]
 dynamic = ["version"]

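The two new dependencies split the work: sqlalchemy supplies the ORM layer, while duckdb-engine registers the "duckdb" dialect that lets SQLAlchemy engines talk to DuckDB. A minimal sanity check of that wiring, as a sketch (the duckdb:///:memory: URL form follows the duckdb-engine documentation; the query is purely illustrative):

    from sqlalchemy import create_engine, text

    # duckdb-engine installs a "duckdb" dialect entry point, so
    # create_engine can resolve duckdb:// URLs once it is installed.
    engine = create_engine("duckdb:///:memory:")

    with engine.connect() as conn:
        # Round-trip a trivial query to confirm the dialect is usable.
        assert conn.execute(text("SELECT 42")).scalar() == 42
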
src/muse/new_input/readers.py

Lines changed: 42 additions & 25 deletions
@@ -1,11 +1,48 @@
-import duckdb
 import numpy as np
 import xarray as xr
+from sqlalchemy import CheckConstraint, ForeignKey
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+
+
+class TableBase(DeclarativeBase):
+    pass
+
+
+class Regions(TableBase):
+    __tablename__ = "regions"
+
+    name: Mapped[str] = mapped_column(primary_key=True)
+
+
+class Commodities(TableBase):
+    __tablename__ = "commodities"
+
+    name: Mapped[str] = mapped_column(primary_key=True)
+    type: Mapped[str] = mapped_column(
+        CheckConstraint("type IN ('energy', 'service', 'material', 'environmental')")
+    )
+    unit: Mapped[str]
+
+
+class Demand(TableBase):
+    __tablename__ = "demand"
+
+    year: Mapped[int] = mapped_column(primary_key=True, autoincrement=False)
+    commodity: Mapped[str] = mapped_column(
+        ForeignKey("commodities.name"), primary_key=True
+    )
+    region: Mapped[str] = mapped_column(
+        ForeignKey("regions.name"), primary_key=True
+    )
+    demand: Mapped[float]
 
 
 def read_inputs(data_dir):
-    data = {}
-    con = duckdb.connect(":memory:")
+    from sqlalchemy import create_engine
+
+    engine = create_engine("duckdb:///:memory:")
+    TableBase.metadata.create_all(engine)
+    con = engine.raw_connection().driver_connection
 
     with open(data_dir / "regions.csv") as f:
         regions = read_regions_csv(f, con)  # noqa: F841
@@ -16,32 +53,20 @@ def read_inputs(data_dir):
     with open(data_dir / "demand.csv") as f:
         demand = read_demand_csv(f, con)  # noqa: F841
 
+    data = {}
     data["global_commodities"] = calculate_global_commodities(commodities)
     return data
 
 
 def read_regions_csv(buffer_, con):
-    sql = """CREATE TABLE regions (
-        name VARCHAR PRIMARY KEY,
-    );
-    """
-    con.sql(sql)
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql("INSERT INTO regions SELECT name FROM rel;")
+    con.execute("INSERT INTO regions SELECT name FROM rel;")
     return con.sql("SELECT name from regions").fetchnumpy()
 
 
 def read_commodities_csv(buffer_, con):
-    sql = """CREATE TABLE commodities (
-        name VARCHAR PRIMARY KEY,
-        type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')),
-        unit VARCHAR,
-    );
-    """
-    con.sql(sql)
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
     con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;")
-
     return con.sql("select name, type, unit from commodities").fetchnumpy()
 
 
@@ -63,14 +88,6 @@ def calculate_global_commodities(commodities):
 
 
 def read_demand_csv(buffer_, con):
-    sql = """CREATE TABLE demand (
-        year BIGINT,
-        commodity VARCHAR REFERENCES commodities(name),
-        region VARCHAR REFERENCES regions(name),
-        demand DOUBLE,
-    );
-    """
-    con.sql(sql)
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
     con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;")
     return con.sql("SELECT * from demand").fetchnumpy()

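The hand-written CREATE TABLE statements deleted above are now generated from the model metadata by TableBase.metadata.create_all(engine). Note that the reader helpers still call DuckDB-native APIs (read_csv, sql, fetchnumpy), which is why read_inputs reaches below the SQLAlchemy layer via engine.raw_connection().driver_connection. To inspect the DDL the models produce without a live database, a sketch using SQLAlchemy's create_mock_engine (not part of this commit):

    from sqlalchemy import create_mock_engine

    from muse.new_input.readers import TableBase

    def dump(sql, *multiparams, **params):
        # Print each compiled DDL statement instead of executing it.
        print(sql.compile(dialect=engine.dialect))

    # A mock engine compiles statements and hands them to `dump`;
    # checkfirst=False because it cannot query for existing tables.
    engine = create_mock_engine("duckdb:///:memory:", dump)
    TableBase.metadata.create_all(engine, checkfirst=False)

The output should mirror the deleted SQL, with one behavioural addition: demand now carries a composite primary key over (year, commodity, region), where the old DDL declared no primary key at all. The new tests below exercise exactly that.
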
tests/test_readers.py

Lines changed: 52 additions & 3 deletions
@@ -868,7 +868,14 @@ def default_new_input(tmp_path):
 
 @fixture
 def con():
-    return duckdb.connect(":memory:")
+    from muse.new_input.readers import TableBase
+    from sqlalchemy import create_engine
+    from sqlalchemy.orm import Session
+
+    engine = create_engine("duckdb:///:memory:")
+    session = Session(engine)
+    TableBase.metadata.create_all(engine)
+    return session.connection().connection
 
 
 @fixture
@@ -899,7 +906,15 @@ def test_read_regions(populate_regions):
     assert populate_regions["name"] == np.array(["R1"])
 
 
-def test_read_new_global_commodities(populate_commodities):
+def test_read_regions_primary_key_constraint(default_new_input, con):
+    from muse.new_input.readers import read_regions_csv
+
+    csv = StringIO("name\nR1\nR1\n")
+    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
+        read_regions_csv(csv, con)
+
+
+def test_read_new_commodities(populate_commodities):
     data = populate_commodities
     assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"]
     assert list(data["type"]) == ["energy"] * 5
@@ -921,7 +936,15 @@ def test_calculate_global_commodities(populate_commodities):
     assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"])
 
 
-def test_read_new_global_commodities_type_constraint(default_new_input, con):
+def test_read_new_commodities_primary_key_constraint(default_new_input, con):
+    from muse.new_input.readers import read_commodities_csv
+
+    csv = StringIO("name,type,unit\nfoo,energy,bar\nfoo,energy,bar\n")
+    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
+        read_commodities_csv(csv, con)
+
+
+def test_read_new_commodities_type_constraint(default_new_input, con):
     from muse.new_input.readers import read_commodities_csv
 
     csv = StringIO("name,type,unit\nfoo,invalid,bar\n")
@@ -957,6 +980,32 @@ def test_new_read_demand_csv_region_constraint(
     read_demand_csv(csv, con)
 
 
+def test_new_read_demand_csv_primary_key_constraint(
+    default_new_input, con, populate_commodities, populate_regions
+):
+    from muse.new_input.readers import read_demand_csv, read_regions_csv
+
+    # Add another region so we can test varying it as a primary key
+    csv = StringIO("name\nR2\n")
+    read_regions_csv(csv, con)
+
+    # all fine so long as one primary key column differs
+    csv = StringIO(
+        """year,commodity_name,region,demand
+2020,gas,R1,0
+2021,gas,R1,0
+2020,heat,R1,0
+2020,gas,R2,0
+"""
+    )
+    read_demand_csv(csv, con)
+
+    # no good if all primary key columns match a previous entry
+    csv = StringIO("year,commodity_name,region,demand\n2020,gas,R1,0")
+    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
+        read_demand_csv(csv, con)
+
+
 @mark.xfail
 def test_demand_dataset(default_new_input):
     import duckdb

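Taken together, the new tests pin down constraints that now originate in the SQLAlchemy models rather than hand-written DDL. A condensed sketch of the duplicate-key case, runnable outside pytest (it assumes, as the fixtures do, that the driver-level connection forwards execute to the native DuckDB connection):

    import duckdb
    from sqlalchemy import create_engine

    from muse.new_input.readers import TableBase

    engine = create_engine("duckdb:///:memory:")
    TableBase.metadata.create_all(engine)
    con = engine.raw_connection().driver_connection

    con.execute("INSERT INTO regions VALUES ('R1')")
    try:
        # A second 'R1' violates the PRIMARY KEY on regions.name.
        con.execute("INSERT INTO regions VALUES ('R1')")
    except duckdb.ConstraintException as err:
        assert "duplicate key" in str(err).lower()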