
Commit 606dd66

Adopt generic read_csv function

1 parent 6fe0104 · commit 606dd66

File tree:
  src/muse/new_input/readers.py
  tests/test_readers.py

2 files changed (+34 −38 lines)


src/muse/new_input/readers.py

Lines changed: 14 additions & 18 deletions
@@ -29,7 +29,9 @@ class Demand(TableBase):
 
     year: Mapped[int] = mapped_column(primary_key=True, autoincrement=False)
     commodity: Mapped[Commodities] = mapped_column(
-        ForeignKey("commodities.name"), primary_key=True
+        ForeignKey("commodities.name"),
+        primary_key=True,
+        info=dict(header="commodity_name"),
     )
     region: Mapped[Regions] = mapped_column(
         ForeignKey("regions.name"), primary_key=True
@@ -45,29 +47,29 @@ def read_inputs(data_dir):
     con = engine.raw_connection().driver_connection
 
     with open(data_dir / "regions.csv") as f:
-        regions = read_regions_csv(f, con)  # noqa: F841
+        regions = read_csv(f, Regions, con)  # noqa: F841
 
     with open(data_dir / "commodities.csv") as f:
-        commodities = read_commodities_csv(f, con)
+        commodities = read_csv(f, Commodities, con)
 
     with open(data_dir / "demand.csv") as f:
-        demand = read_demand_csv(f, con)  # noqa: F841
+        demand = read_csv(f, Demand, con)  # noqa: F841
 
     data = {}
     data["global_commodities"] = calculate_global_commodities(commodities)
     return data
 
 
-def read_regions_csv(buffer_, con):
-    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.execute("INSERT INTO regions SELECT name FROM rel;")
-    return con.sql("SELECT name from regions").fetchnumpy()
-
+def read_csv(buffer_, table_class, con):
+    table_name = table_class.__tablename__
+    columns = ", ".join(
+        column.info.get("header", column.name)
+        for column in table_class.__table__.columns
+    )
 
-def read_commodities_csv(buffer_, con):
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;")
-    return con.sql("select name, type, unit from commodities").fetchnumpy()
+    con.execute(f"INSERT INTO {table_name} SELECT {columns} FROM rel")
+    return con.execute(f"SELECT * from {table_name}").fetchnumpy()
 
 
 def calculate_global_commodities(commodities):
@@ -85,9 +87,3 @@ def calculate_global_commodities(commodities):
 
     data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array))
     return data
-
-
-def read_demand_csv(buffer_, con):
-    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;")
-    return con.sql("SELECT * from demand").fetchnumpy()
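
To make the generic mechanism concrete, here is a minimal sketch (not part of the commit) of the column list read_csv derives for Demand. It assumes the mapping shown above, and that Demand also declares a demand column, as the old hard-coded SELECT implies. The info=dict(header="commodity_name") entry is what lets the ORM attribute commodity line up with the commodity_name header in demand.csv:

    # Sketch: how read_csv builds its column list from the ORM metadata.
    # Assumes Demand declares year, commodity (header "commodity_name"),
    # region and demand columns, per the diff and the old SQL it replaces.
    from muse.new_input.readers import Demand

    columns = ", ".join(
        column.info.get("header", column.name)
        for column in Demand.__table__.columns
    )
    assert columns == "year, commodity_name, region, demand"
    # read_csv therefore issues the same statement the removed
    # read_demand_csv hard-coded:
    #   INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel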

tests/test_readers.py

Lines changed: 20 additions & 20 deletions
@@ -880,38 +880,38 @@ def con():
 
 @fixture
 def populate_regions(default_new_input, con):
-    from muse.new_input.readers import read_regions_csv
+    from muse.new_input.readers import Regions, read_csv
 
     with open(default_new_input / "regions.csv") as f:
-        return read_regions_csv(f, con)
+        return read_csv(f, Regions, con)
 
 
 @fixture
 def populate_commodities(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv
 
     with open(default_new_input / "commodities.csv") as f:
-        return read_commodities_csv(f, con)
+        return read_csv(f, Commodities, con)
 
 
 @fixture
 def populate_demand(default_new_input, con, populate_regions, populate_commodities):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv
 
     with open(default_new_input / "demand.csv") as f:
-        return read_demand_csv(f, con)
+        return read_csv(f, Demand, con)
 
 
 def test_read_regions(populate_regions):
     assert populate_regions["name"] == np.array(["R1"])
 
 
 def test_read_regions_primary_key_constraint(default_new_input, con):
-    from muse.new_input.readers import read_regions_csv
+    from muse.new_input.readers import Regions, read_csv
 
     csv = StringIO("name\nR1\nR1\n")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_regions_csv(csv, con)
+        read_csv(csv, Regions, con)
 
 
 def test_read_new_commodities(populate_commodities):
@@ -937,19 +937,19 @@ def test_calculate_global_commodities(populate_commodities):
 
 
 def test_read_new_commodities_primary_key_constraint(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv
 
     csv = StringIO("name,type,unit\nfoo,energy,bar\nfoo,energy,bar\n")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_commodities_csv(csv, con)
+        read_csv(csv, Commodities, con)
 
 
 def test_read_new_commodities_type_constraint(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv
 
     csv = StringIO("name,type,unit\nfoo,invalid,bar\n")
     with raises(duckdb.ConstraintException):
-        read_commodities_csv(csv, con)
+        read_csv(csv, Commodities, con)
 
 
 def test_new_read_demand_csv(populate_demand):
@@ -963,31 +963,31 @@ def test_new_read_demand_csv(populate_demand):
 def test_new_read_demand_csv_commodity_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv
 
     csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n")
     with raises(duckdb.ConstraintException, match=".*foreign key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)
 
 
 def test_new_read_demand_csv_region_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv
 
     csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n")
     with raises(duckdb.ConstraintException, match=".*foreign key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)
 
 
 def test_new_read_demand_csv_primary_key_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv, read_regions_csv
+    from muse.new_input.readers import Demand, Regions, read_csv
 
     # Add another region so we can test varying it as a primary key
     csv = StringIO("name\nR2\n")
-    read_regions_csv(csv, con)
+    read_csv(csv, Regions, con)
 
     # all fine so long as one primary key column differs
     csv = StringIO(
@@ -998,12 +998,12 @@ def test_new_read_demand_csv_primary_key_constraint(
 2020,gas,R2,0
 """
     )
-    read_demand_csv(csv, con)
+    read_csv(csv, Demand, con)
 
     # no good if all primary key columns match a previous entry
     csv = StringIO("year,commodity_name,region,demand\n2020,gas,R1,0")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)
 
 
 @mark.xfail
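
For completeness, a hedged usage sketch (not from the commit) of the unified entry point, mirroring the fixtures above. It assumes a DuckDB connection on which the ORM schema has already been created, as the con fixture presumably arranges:

    # Usage sketch: any mapped table now goes through the same code path.
    from io import StringIO

    from muse.new_input.readers import Regions, read_csv

    regions = read_csv(StringIO("name\nR1\nR2\n"), Regions, con)
    assert list(regions["name"]) == ["R1", "R2"]  # fetchnumpy() dict of arrays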
