
Commit 2d29bc7

cc-atsmbland authored and committed
Adopt generic read_csv function
1 parent 5b80c67 commit 2d29bc7

File tree

2 files changed (+34, -38 lines)


src/muse/new_input/readers.py

Lines changed: 14 additions & 18 deletions
@@ -29,7 +29,9 @@ class Demand(TableBase):

     year: Mapped[int] = mapped_column(primary_key=True, autoincrement=False)
     commodity: Mapped[Commodities] = mapped_column(
-        ForeignKey("commodities.name"), primary_key=True
+        ForeignKey("commodities.name"),
+        primary_key=True,
+        info=dict(header="commodity_name"),
     )
     region: Mapped[Regions] = mapped_column(
         ForeignKey("regions.name"), primary_key=True
@@ -45,29 +47,29 @@ def read_inputs(data_dir):
     con = engine.raw_connection().driver_connection

     with open(data_dir / "regions.csv") as f:
-        regions = read_regions_csv(f, con)  # noqa: F841
+        regions = read_csv(f, Regions, con)  # noqa: F841

     with open(data_dir / "commodities.csv") as f:
-        commodities = read_commodities_csv(f, con)
+        commodities = read_csv(f, Commodities, con)

     with open(data_dir / "demand.csv") as f:
-        demand = read_demand_csv(f, con)  # noqa: F841
+        demand = read_csv(f, Demand, con)  # noqa: F841

     data = {}
     data["global_commodities"] = calculate_global_commodities(commodities)
     return data


-def read_regions_csv(buffer_, con):
-    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.execute("INSERT INTO regions SELECT name FROM rel;")
-    return con.sql("SELECT name from regions").fetchnumpy()
-
+def read_csv(buffer_, table_class, con):
+    table_name = table_class.__tablename__
+    columns = ", ".join(
+        column.info.get("header", column.name)
+        for column in table_class.__table__.columns
+    )

-def read_commodities_csv(buffer_, con):
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;")
-    return con.sql("select name, type, unit from commodities").fetchnumpy()
+    con.execute(f"INSERT INTO {table_name} SELECT {columns} FROM rel")
+    return con.execute(f"SELECT * from {table_name}").fetchnumpy()


 def calculate_global_commodities(commodities):
@@ -85,9 +87,3 @@ def calculate_global_commodities(commodities):

     data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array))
     return data
-
-
-def read_demand_csv(buffer_, con):
-    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;")
-    return con.sql("SELECT * from demand").fetchnumpy()
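
For context on the new helper: read_csv derives both the target table and the CSV-to-column mapping from the SQLAlchemy model, with each column's info dict (e.g. info=dict(header="commodity_name")) bridging CSV headers that differ from column names. The sketch below reproduces that idea in isolation; the cut-down Demand model and the hand-written CREATE TABLE are illustrative assumptions, not part of this commit.

from io import StringIO

import duckdb
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Demand(Base):  # simplified stand-in for the project's model
    __tablename__ = "demand"

    year: Mapped[int] = mapped_column(primary_key=True, autoincrement=False)
    # The CSV header is "commodity_name" but the column is "commodity";
    # the info dict records the override that read_csv consults.
    commodity: Mapped[str] = mapped_column(
        primary_key=True, info=dict(header="commodity_name")
    )
    region: Mapped[str] = mapped_column(primary_key=True)
    demand: Mapped[float] = mapped_column()


# The same SELECT-list construction as the new read_csv:
columns = ", ".join(
    column.info.get("header", column.name) for column in Demand.__table__.columns
)
assert columns == "year, commodity_name, region, demand"

# Round trip through DuckDB; `rel` is picked up by name in the SQL below
# via DuckDB's replacement scan, just as in readers.py.
con = duckdb.connect()
con.execute(
    "CREATE TABLE demand (year INTEGER, commodity VARCHAR, region VARCHAR,"
    " demand DOUBLE, PRIMARY KEY (year, commodity, region))"
)
rel = con.read_csv(
    StringIO("year,commodity_name,region,demand\n2020,heat,R1,1.5\n"),
    header=True,
    delimiter=",",
)
con.execute(f"INSERT INTO demand SELECT {columns} FROM rel")
print(con.execute("SELECT * from demand").fetchnumpy())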

tests/test_readers.py

Lines changed: 20 additions & 20 deletions
@@ -341,38 +341,38 @@ def con():

 @fixture
 def populate_regions(default_new_input, con):
-    from muse.new_input.readers import read_regions_csv
+    from muse.new_input.readers import Regions, read_csv

     with open(default_new_input / "regions.csv") as f:
-        return read_regions_csv(f, con)
+        return read_csv(f, Regions, con)


 @fixture
 def populate_commodities(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv

     with open(default_new_input / "commodities.csv") as f:
-        return read_commodities_csv(f, con)
+        return read_csv(f, Commodities, con)


 @fixture
 def populate_demand(default_new_input, con, populate_regions, populate_commodities):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv

     with open(default_new_input / "demand.csv") as f:
-        return read_demand_csv(f, con)
+        return read_csv(f, Demand, con)


 def test_read_regions(populate_regions):
     assert populate_regions["name"] == np.array(["R1"])


 def test_read_regions_primary_key_constraint(default_new_input, con):
-    from muse.new_input.readers import read_regions_csv
+    from muse.new_input.readers import Regions, read_csv

     csv = StringIO("name\nR1\nR1\n")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_regions_csv(csv, con)
+        read_csv(csv, Regions, con)


 def test_read_new_commodities(populate_commodities):
@@ -398,19 +398,19 @@ def test_calculate_global_commodities(populate_commodities):


 def test_read_new_commodities_primary_key_constraint(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv

     csv = StringIO("name,type,unit\nfoo,energy,bar\nfoo,energy,bar\n")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_commodities_csv(csv, con)
+        read_csv(csv, Commodities, con)


 def test_read_new_commodities_type_constraint(default_new_input, con):
-    from muse.new_input.readers import read_commodities_csv
+    from muse.new_input.readers import Commodities, read_csv

     csv = StringIO("name,type,unit\nfoo,invalid,bar\n")
     with raises(duckdb.ConstraintException):
-        read_commodities_csv(csv, con)
+        read_csv(csv, Commodities, con)


 def test_new_read_demand_csv(populate_demand):
@@ -424,31 +424,31 @@ def test_new_read_demand_csv(populate_demand):
 def test_new_read_demand_csv_commodity_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv

     csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n")
     with raises(duckdb.ConstraintException, match=".*foreign key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)


 def test_new_read_demand_csv_region_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv
+    from muse.new_input.readers import Demand, read_csv

     csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n")
     with raises(duckdb.ConstraintException, match=".*foreign key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)


 def test_new_read_demand_csv_primary_key_constraint(
     default_new_input, con, populate_commodities, populate_regions
 ):
-    from muse.new_input.readers import read_demand_csv, read_regions_csv
+    from muse.new_input.readers import Demand, Regions, read_csv

     # Add another region so we can test varying it as a primary key
     csv = StringIO("name\nR2\n")
-    read_regions_csv(csv, con)
+    read_csv(csv, Regions, con)

     # all fine so long as one primary key column differs
     csv = StringIO(
@@ -459,12 +459,12 @@ def test_new_read_demand_csv_primary_key_constraint(
 2020,gas,R2,0
 """
     )
-    read_demand_csv(csv, con)
+    read_csv(csv, Demand, con)

     # no good if all primary key columns match a previous entry
     csv = StringIO("year,commodity_name,region,demand\n2020,gas,R1,0")
     with raises(duckdb.ConstraintException, match=".*duplicate key.*"):
-        read_demand_csv(csv, con)
+        read_csv(csv, Demand, con)


 @mark.xfail
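
Worth noting: the constraint failures asserted above come from DuckDB enforcing the table schema, not from validation inside read_csv itself. A minimal illustration, assuming only the duckdb package (the table here stands in for whatever the con fixture creates):

import duckdb

con = duckdb.connect()
con.execute("CREATE TABLE regions (name VARCHAR PRIMARY KEY)")
con.execute("INSERT INTO regions VALUES ('R1')")
try:
    # Re-inserting the same primary key raises ConstraintException,
    # which is what test_read_regions_primary_key_constraint relies on.
    con.execute("INSERT INTO regions VALUES ('R1')")
except duckdb.ConstraintException as exc:
    print(exc)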
