11import json
2+ from uuid import uuid4
23
34from geoalchemy2 import WKTElement
5+ from google .cloud .sql .connector .instance import logger
46from sqlalchemy import text
57
68from database .database import Database
7- from database_gen .sqlacodegen_models import Gtfsdataset , Validationreport , Gtfsfeed , Notice , Feature , t_feedsearch
8- from scripts .populate_db import set_up_configs
9+ from database_gen .sqlacodegen_models import (
10+ Gtfsdataset ,
11+ Validationreport ,
12+ Gtfsfeed ,
13+ Notice ,
14+ Feature ,
15+ t_feedsearch ,
16+ Location ,
17+ )
18+ from scripts .populate_db import set_up_configs , DatabasePopulateHelper
919from utils .logger import Logger
1020
1121
@@ -36,70 +46,78 @@ def populate_test_datasets(self, filepath):
3646 with open (filepath ) as f :
3747 data = json .load (f )
3848
49+ # GTFS Feeds
50+ if "feeds" in data :
51+ self .populate_test_feeds (data ["feeds" ])
52+
3953 # GTFS Datasets
4054 dataset_dict = {}
41- for dataset in data ["datasets" ]:
42- # query the db using feed_id to get the feed object
43- gtfsfeed = self .db .session .query (Gtfsfeed ).filter (Gtfsfeed .stable_id == dataset ["feed_stable_id" ]).all ()
44- if not gtfsfeed :
45- self .logger .error (f"No feed found with stable_id: { dataset ['feed_stable_id' ]} " )
46- continue
47-
48- gtfs_dataset = Gtfsdataset (
49- id = dataset ["id" ],
50- feed_id = gtfsfeed [0 ].id ,
51- stable_id = dataset ["id" ],
52- latest = dataset ["latest" ],
53- hosted_url = dataset ["hosted_url" ],
54- hash = dataset ["hash" ],
55- downloaded_at = dataset ["downloaded_at" ],
56- bounding_box = None
57- if dataset .get ("bounding_box" ) is None
58- else WKTElement (dataset ["bounding_box" ], srid = 4326 ),
59- validation_reports = [],
60- )
61- dataset_dict [dataset ["id" ]] = gtfs_dataset
62- self .db .session .add (gtfs_dataset )
55+ if "datasets" in data :
56+ for dataset in data ["datasets" ]:
57+ # query the db using feed_id to get the feed object
58+ gtfsfeed = self .db .session .query (Gtfsfeed ).filter (Gtfsfeed .stable_id == dataset ["feed_stable_id" ]).all ()
59+ if not gtfsfeed :
60+ self .logger .error (f"No feed found with stable_id: { dataset ['feed_stable_id' ]} " )
61+ continue
62+
63+ gtfs_dataset = Gtfsdataset (
64+ id = dataset ["id" ],
65+ feed_id = gtfsfeed [0 ].id ,
66+ stable_id = dataset ["id" ],
67+ latest = dataset ["latest" ],
68+ hosted_url = dataset ["hosted_url" ],
69+ hash = dataset ["hash" ],
70+ downloaded_at = dataset ["downloaded_at" ],
71+ bounding_box = None
72+ if dataset .get ("bounding_box" ) is None
73+ else WKTElement (dataset ["bounding_box" ], srid = 4326 ),
74+ validation_reports = [],
75+ )
76+ dataset_dict [dataset ["id" ]] = gtfs_dataset
77+ self .db .session .add (gtfs_dataset )
6378
6479 # Validation reports
65- validation_report_dict = {}
66- for report in data ["validation_reports" ]:
67- validation_report = Validationreport (
68- id = report ["id" ],
69- validator_version = report ["validator_version" ],
70- validated_at = report ["validated_at" ],
71- html_report = report ["html_report" ],
72- json_report = report ["json_report" ],
73- features = [],
74- )
75- dataset_dict [report ["dataset_id" ]].validation_reports .append (validation_report )
76- validation_report_dict [report ["id" ]] = validation_report
77- self .db .session .add (validation_report )
80+ if "validation_reports" in data :
81+ validation_report_dict = {}
82+ for report in data ["validation_reports" ]:
83+ validation_report = Validationreport (
84+ id = report ["id" ],
85+ validator_version = report ["validator_version" ],
86+ validated_at = report ["validated_at" ],
87+ html_report = report ["html_report" ],
88+ json_report = report ["json_report" ],
89+ features = [],
90+ )
91+ dataset_dict [report ["dataset_id" ]].validation_reports .append (validation_report )
92+ validation_report_dict [report ["id" ]] = validation_report
93+ self .db .session .add (validation_report )
7894
7995 # Notices
80- for report_notice in data ["notices" ]:
81- notice = Notice (
82- dataset_id = report_notice ["dataset_id" ],
83- validation_report_id = report_notice ["validation_report_id" ],
84- severity = report_notice ["severity" ],
85- notice_code = report_notice ["notice_code" ],
86- total_notices = report_notice ["total_notices" ],
87- )
88- self .db .session .add (notice )
96+ if "notices" in data :
97+ for report_notice in data ["notices" ]:
98+ notice = Notice (
99+ dataset_id = report_notice ["dataset_id" ],
100+ validation_report_id = report_notice ["validation_report_id" ],
101+ severity = report_notice ["severity" ],
102+ notice_code = report_notice ["notice_code" ],
103+ total_notices = report_notice ["total_notices" ],
104+ )
105+ self .db .session .add (notice )
89106 # Features
90- for featureName in data ["features" ]:
91- feature = Feature (name = featureName )
92- self .db .session .add (feature )
107+ if "features" in data :
108+ for featureName in data ["features" ]:
109+ feature = Feature (name = featureName )
110+ self .db .session .add (feature )
93111
94112 # Features in Validation Reports
95- for report_features in data ["validation_report_features" ]:
96- validation_report_dict [report_features ["validation_report_id" ]].features .append (
97- self .db .session .query (Feature ).filter (Feature .name == report_features ["feature_name" ]).first ()
98- )
99-
100- self .db .session .execute (text (f"REFRESH MATERIALIZED VIEW CONCURRENTLY { t_feedsearch .name } " ))
113+ if "validation_report_features" in data :
114+ for report_features in data ["validation_report_features" ]:
115+ validation_report_dict [report_features ["validation_report_id" ]].features .append (
116+ self .db .session .query (Feature ).filter (Feature .name == report_features ["feature_name" ]).first ()
117+ )
101118
102119 self .db .session .commit ()
120+ self .db .session .execute (text (f"REFRESH MATERIALIZED VIEW CONCURRENTLY { t_feedsearch .name } " ))
103121
104122 def populate (self ):
105123 """
@@ -116,6 +134,47 @@ def populate(self):
116134
117135 self .logger .info ("Database populated with test data" )
118136
137+ def populate_test_feeds (self , feeds_data ):
138+ for feed_data in feeds_data :
139+ feed = Gtfsfeed (
140+ id = str (uuid4 ()),
141+ stable_id = feed_data ["id" ],
142+ data_type = feed_data ["data_type" ],
143+ status = feed_data ["status" ],
144+ created_at = feed_data ["created_at" ],
145+ provider = feed_data ["provider" ],
146+ feed_name = feed_data ["feed_name" ],
147+ note = feed_data ["note" ],
148+ authentication_info_url = None ,
149+ api_key_parameter_name = None ,
150+ license_url = None ,
151+ feed_contact_email = feed_data ["feed_contact_email" ],
152+ producer_url = feed_data ["source_info" ]["producer_url" ],
153+ )
154+ locations = []
155+ for location_data in feed_data ["locations" ]:
156+ location_id = DatabasePopulateHelper .get_location_id (
157+ location_data ["country_code" ],
158+ location_data ["subdivision_name" ],
159+ location_data ["municipality" ],
160+ )
161+ location = self .db .session .get (Location , location_id )
162+ location = (
163+ location
164+ if location
165+ else Location (
166+ id = location_id ,
167+ country_code = location_data ["country_code" ],
168+ subdivision_name = location_data ["subdivision_name" ],
169+ municipality = location_data ["municipality" ],
170+ country = location_data ["country" ],
171+ )
172+ )
173+ locations .append (location )
174+ feed .locations = locations
175+ self .db .session .add (feed )
176+ logger .info (f"Added feed { feed .stable_id } " )
177+
119178
120179if __name__ == "__main__" :
121180 db_helper = DatabasePopulateTestDataHelper (set_up_configs ())
0 commit comments