Skip to content

Commit cf7aa3e

Browse files
cc-atsmbland
authored and committed
Define db schema using SQLAlchemy
1 parent 560007d commit cf7aa3e

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ dependencies = [
3737
"pypubsub",
3838
"tomlkit",
3939
"duckdb",
40-
"fsspec"
40+
"fsspec",
41+
"sqlalchemy",
42+
"duckdb-engine"
4143
]
4244
dynamic = ["version"]
4345

src/muse/new_input/readers.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,48 @@
1-
import duckdb
21
import numpy as np
32
import xarray as xr
3+
from sqlalchemy import CheckConstraint, ForeignKey
4+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
5+
6+
7+
class TableBase(DeclarativeBase):
    """Declarative base shared by every ORM table class in this module."""
9+
10+
11+
class Regions(TableBase):
    """A model region, identified solely by its name."""

    __tablename__ = "regions"

    # The region name doubles as the primary key.
    name: Mapped[str] = mapped_column(primary_key=True)
15+
16+
17+
class Commodities(TableBase):
    """A commodity tracked by the model, keyed by its name."""

    __tablename__ = "commodities"

    name: Mapped[str] = mapped_column(primary_key=True)
    # Schema-level check: only these four commodity kinds are admitted.
    type: Mapped[str] = mapped_column(
        CheckConstraint("type IN ('energy', 'service', 'material', 'environmental')")
    )
    unit: Mapped[str]
25+
26+
27+
class Demand(TableBase):
    """Demand for a commodity in a given region and year.

    The (year, commodity, region) triple forms the composite primary key,
    matching the duplicate-key behaviour the tests assert.
    """

    __tablename__ = "demand"

    year: Mapped[int] = mapped_column(primary_key=True, autoincrement=False)
    # Fix: these are plain string foreign-key COLUMNS, so they must be
    # annotated Mapped[str]. Annotating a mapped_column with an ORM class
    # (Mapped[Commodities] / Mapped[Regions]) makes SQLAlchemy fail to
    # resolve a Core column type -- classes belong on relationship(), not
    # on columns.
    commodity: Mapped[str] = mapped_column(
        ForeignKey("commodities.name"), primary_key=True
    )
    region: Mapped[str] = mapped_column(
        ForeignKey("regions.name"), primary_key=True
    )
    demand: Mapped[float]
438

539

640
def read_inputs(data_dir):
7-
data = {}
8-
con = duckdb.connect(":memory:")
41+
from sqlalchemy import create_engine
42+
43+
engine = create_engine("duckdb:///:memory:")
44+
TableBase.metadata.create_all(engine)
45+
con = engine.raw_connection().driver_connection
946

1047
with open(data_dir / "regions.csv") as f:
1148
regions = read_regions_csv(f, con) # noqa: F841
@@ -16,32 +53,20 @@ def read_inputs(data_dir):
1653
with open(data_dir / "demand.csv") as f:
1754
demand = read_demand_csv(f, con) # noqa: F841
1855

56+
data = {}
1957
data["global_commodities"] = calculate_global_commodities(commodities)
2058
return data
2159

2260

2361
def read_regions_csv(buffer_, con):
    """Load region names from a CSV buffer into the regions table.

    Returns the table contents as a dict of numpy arrays.
    """
    # `rel` looks unused, but the SQL below presumably resolves it via
    # DuckDB's replacement scan of local variables -- hence the noqa.
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    # NOTE(review): this uses con.execute while the sibling readers use
    # con.sql -- confirm whether the inconsistency is deliberate.
    con.execute("INSERT INTO regions SELECT name FROM rel;")
    return con.sql("SELECT name from regions").fetchnumpy()
3265

3366

3467
def read_commodities_csv(buffer_, con):
    """Load commodity rows from a CSV buffer into the commodities table.

    Returns name, type and unit columns as a dict of numpy arrays.
    """
    # The INSERT references the local relation `rel` by name (DuckDB
    # replacement scan), so the binding must stay despite being "unused".
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;")
    return con.sql("select name, type, unit from commodities").fetchnumpy()
4671

4772

@@ -63,14 +88,6 @@ def calculate_global_commodities(commodities):
6388

6489

6590
def read_demand_csv(buffer_, con):
    """Load demand rows from a CSV buffer into the demand table.

    The CSV's commodity_name column maps onto the table's commodity
    column. Returns the whole table as a dict of numpy arrays.
    """
    # Bound so the SQL below can reference it by name; hence the noqa.
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;")
    return con.sql("SELECT * from demand").fetchnumpy()

tests/test_readers.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,14 @@ def default_new_input(tmp_path):
329329

330330
@fixture
def con():
    """An in-memory DuckDB connection with the ORM schema pre-created."""
    from muse.new_input.readers import TableBase
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("duckdb:///:memory:")
    session = Session(engine)
    TableBase.metadata.create_all(engine)
    # Hand back the underlying DBAPI-level connection, which is what the
    # reader functions expect.
    return session.connection().connection
333340

334341

335342
@fixture
@@ -360,7 +367,15 @@ def test_read_regions(populate_regions):
360367
assert populate_regions["name"] == np.array(["R1"])
361368

362369

363-
def test_read_new_global_commodities(populate_commodities):
370+
def test_read_regions_primary_key_constraint(default_new_input, con):
    """A duplicated region name must violate the primary key."""
    from muse.new_input.readers import read_regions_csv

    duplicated = StringIO("name\nR1\nR1\n")
    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
        read_regions_csv(duplicated, con)
376+
377+
378+
def test_read_new_commodities(populate_commodities):
364379
data = populate_commodities
365380
assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"]
366381
assert list(data["type"]) == ["energy"] * 5
@@ -382,7 +397,15 @@ def test_calculate_global_commodities(populate_commodities):
382397
assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"])
383398

384399

385-
def test_read_new_global_commodities_type_constraint(default_new_input, con):
400+
def test_read_new_commodities_primary_key_constraint(default_new_input, con):
    """A duplicated commodity name must violate the primary key."""
    from muse.new_input.readers import read_commodities_csv

    duplicated = StringIO("name,type,unit\nfoo,energy,bar\nfoo,energy,bar\n")
    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
        read_commodities_csv(duplicated, con)
406+
407+
408+
def test_read_new_commodities_type_constraint(default_new_input, con):
386409
from muse.new_input.readers import read_commodities_csv
387410

388411
csv = StringIO("name,type,unit\nfoo,invalid,bar\n")
@@ -418,6 +441,32 @@ def test_new_read_demand_csv_region_constraint(
418441
read_demand_csv(csv, con)
419442

420443

444+
def test_new_read_demand_csv_primary_key_constraint(
    default_new_input, con, populate_commodities, populate_regions
):
    """Demand's composite (year, commodity, region) key rejects exact duplicates."""
    from muse.new_input.readers import read_demand_csv, read_regions_csv

    # A second region lets us vary the region component of the key.
    read_regions_csv(StringIO("name\nR2\n"), con)

    # Rows differing in at least one key column are all accepted.
    distinct_rows = StringIO(
        """year,commodity_name,region,demand
2020,gas,R1,0
2021,gas,R1,0
2020,heat,R1,0
2020,gas,R2,0
"""
    )
    read_demand_csv(distinct_rows, con)

    # A row matching an existing entry on every key column is rejected.
    duplicate_row = StringIO("year,commodity_name,region,demand\n2020,gas,R1,0")
    with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
        read_demand_csv(duplicate_row, con)
468+
469+
421470
@mark.xfail
422471
def test_demand_dataset(default_new_input):
423472
import duckdb

0 commit comments

Comments
 (0)