Skip to content
3 changes: 3 additions & 0 deletions functions-python/tasks_executor/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ google-cloud-storage
# Configuration
python-dotenv==1.0.0
pycountry

# Other utilities
pandas
7 changes: 6 additions & 1 deletion functions-python/tasks_executor/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functions_framework

from shared.helpers.logger import init_logger
from tasks.data_import.transitfeeds.sync_transitfeeds import sync_transitfeeds_handler
from tasks.dataset_files.rebuild_missing_dataset_files import (
rebuild_missing_dataset_files_handler,
)
Expand All @@ -38,7 +39,7 @@
from tasks.geojson.update_geojson_files_precision import (
update_geojson_files_precision_handler,
)
from tasks.data_import.import_jbda_feeds import import_jbda_handler
from tasks.data_import.jbda.import_jbda_feeds import import_jbda_handler

from tasks.licenses.populate_license_rules import (
populate_license_rules_handler,
Expand Down Expand Up @@ -98,6 +99,10 @@
"description": "Populates licenses and license-rules in the database from a predefined JSON source.",
"handler": populate_licenses_handler,
},
"sync_transitfeeds_data": {
"description": "Syncs data from TransitFeeds to the database.",
"handler": sync_transitfeeds_handler,
},
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
import uuid
from datetime import datetime
from typing import Tuple, Type, TypeVar

from sqlalchemy import select
from sqlalchemy.orm import Session

from shared.database_gen.sqlacodegen_models import (
Feed,
Officialstatushistory,
Entitytype,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable bound to Feed; the helpers below use it so that passing a
# model class yields a return annotation of that same class.
T = TypeVar("T", bound="Feed")


def _get_or_create_entity_type(session: Session, entity_type_name: str) -> Entitytype:
    """Return the Entitytype named ``entity_type_name``, creating it if absent.

    Args:
        session: Session used for the lookup and, if needed, the insert.
        entity_type_name: Value matched against ``Entitytype.name``.

    Returns:
        The existing row, or a new one that has been added to ``session``
        and flushed (this function does not commit).
    """
    logger.debug("Looking up Entitytype name=%s", entity_type_name)
    lookup = select(Entitytype).where(Entitytype.name == entity_type_name)
    entity_type = session.scalar(lookup)

    if entity_type:
        logger.debug("Found existing Entitytype name=%s", entity_type_name)
    else:
        entity_type = Entitytype(name=entity_type_name)
        session.add(entity_type)
        # Flush so the new row is sent to the database within this transaction.
        session.flush()
        logger.info("Created Entitytype name=%s", entity_type_name)

    return entity_type


def get_feed(
    session: Session,
    stable_id: str,
    model: Type[T] = Feed,
) -> T | None:
    """Fetch the feed whose ``stable_id`` matches, or ``None`` when none exists.

    Args:
        session: Session used to run the query.
        stable_id: Value matched against ``model.stable_id``.
        model: Mapped class to query; defaults to ``Feed``.

    Returns:
        The result of ``session.scalar`` for the lookup. The outcome is
        logged at debug level either way.
    """
    logger.debug("Lookup feed stable_id=%s", stable_id)
    query = select(model).where(model.stable_id == stable_id)
    found = session.scalar(query)

    if not found:
        logger.debug("No Feed found with stable_id=%s", stable_id)
        return found

    logger.debug("Found existing feed stable_id=%s id=%s", stable_id, found.id)
    return found


def _get_or_create_feed(
    session: Session,
    model: Type[T],
    stable_id: str,
    data_type: str,
    is_official: bool = True,
    official_notes: str = "Imported from JBDA as official feed.",
    reviewer_email: str = "[email protected]",
) -> Tuple[T, bool]:
    """Get or create a Feed subclass (Gtfsfeed, Gtfsrealtimefeed) by stable_id.

    Args:
        session: Session used for the lookup and, if needed, the insert.
        model: Mapped feed class to query and instantiate.
        stable_id: Value matched against ``model.stable_id``.
        data_type: Stored on a newly created feed as ``data_type``.
        is_official: Stored as ``official`` on a new feed; when true, an
            initial ``Officialstatushistory`` entry is attached as well.
        official_notes: ``notes`` of that initial history entry.
        reviewer_email: ``reviewer_email`` of that initial history entry.

    Returns:
        ``(feed, created)`` where ``created`` is ``False`` if a feed with this
        ``stable_id`` already existed and ``True`` if a new one was added to
        the session and flushed (this function does not commit).
    """
    # Resolve the display name once; it is reused by every log call below.
    model_name = getattr(model, "__name__", str(model))
    logger.debug("Lookup feed model=%s stable_id=%s", model_name, stable_id)

    feed = session.scalar(select(model).where(model.stable_id == stable_id))
    if feed:
        logger.info(
            "Found existing %s stable_id=%s id=%s",
            model_name,
            stable_id,
            feed.id,
        )
        return feed, False

    new_id = str(uuid.uuid4())
    # Take one timestamp so the feed's official_updated_at and the history
    # entry describing the same event carry exactly the same value.
    created_at = datetime.now()
    feed = model(
        id=new_id,
        data_type=data_type,
        stable_id=stable_id,
        official=is_official,
        official_updated_at=created_at,
    )
    if is_official:
        feed.officialstatushistories = [
            Officialstatushistory(
                is_official=True,
                reviewer_email=reviewer_email,
                timestamp=created_at,
                notes=official_notes,
            )
        ]
    session.add(feed)
    session.flush()
    logger.info(
        "Created %s stable_id=%s id=%s data_type=%s",
        model_name,
        stable_id,
        new_id,
        data_type,
    )
    return feed, True
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,28 @@
import os
import uuid
from datetime import datetime
from typing import Optional, Tuple, Dict, Any, List, Final, Type, TypeVar
from typing import Optional, Tuple, Dict, Any, List, Final, TypeVar

import requests
import pycountry
import requests
from sqlalchemy import select, and_
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from shared.common.locations_utils import create_or_get_location
from shared.database.database import with_db_session
from shared.database_gen.sqlacodegen_models import (
Feed,
Gtfsfeed,
Gtfsrealtimefeed,
Entitytype,
Feedrelatedlink,
Externalid,
Officialstatushistory,
)

from shared.helpers.pub_sub import trigger_dataset_download
from tasks.data_import.data_import_utils import (
_get_or_create_entity_type,
_get_or_create_feed,
)

# Type variable bound to Feed, for annotating helpers generic over feed models.
T = TypeVar("T", bound="Feed")

Expand Down Expand Up @@ -99,20 +100,6 @@ def import_jbda_handler(payload: dict | None = None) -> dict:
return result


def _get_or_create_entity_type(session: Session, entity_type_name: str) -> Entitytype:
    """Look up an Entitytype by name and create it when it does not exist.

    A newly created row is added to ``session`` and flushed, not committed,
    before being returned.
    """
    logger.debug("Looking up Entitytype name=%s", entity_type_name)
    by_name = select(Entitytype).where(Entitytype.name == entity_type_name)
    existing = session.scalar(by_name)
    if existing:
        logger.debug("Found existing Entitytype name=%s", entity_type_name)
        return existing

    created = Entitytype(name=entity_type_name)
    session.add(created)
    session.flush()
    logger.info("Created Entitytype name=%s", entity_type_name)
    return created


def get_gtfs_file_url(
detail_body: Dict[str, Any], rid: str = "current"
) -> Optional[str]:
Expand Down Expand Up @@ -144,53 +131,6 @@ def get_gtfs_file_url(
return None


def _get_or_create_feed(
    session: Session, model: Type[T], stable_id: str, data_type: str
) -> Tuple[T, bool]:
    """Get or create a Feed subclass (Gtfsfeed, Gtfsrealtimefeed) by stable_id.

    Args:
        session: Session used for the lookup and, if needed, the insert.
        model: Mapped feed class to query and instantiate.
        stable_id: Value matched against ``model.stable_id``.
        data_type: Stored on a newly created feed as ``data_type``.

    Returns:
        ``(feed, created)`` where ``created`` is ``False`` if a feed with this
        ``stable_id`` already existed and ``True`` if a new one was added to
        the session and flushed (this function does not commit). New feeds
        are marked official and given one initial ``Officialstatushistory``
        entry.
    """
    # Resolve the display name once; it is reused by every log call below.
    model_name = getattr(model, "__name__", str(model))
    logger.debug("Lookup feed model=%s stable_id=%s", model_name, stable_id)

    feed = session.scalar(select(model).where(model.stable_id == stable_id))
    if feed:
        logger.info(
            "Found existing %s stable_id=%s id=%s",
            model_name,
            stable_id,
            feed.id,
        )
        return feed, False

    new_id = str(uuid.uuid4())
    # Take one timestamp so the feed's official_updated_at and the history
    # entry describing the same event carry exactly the same value.
    created_at = datetime.now()
    feed = model(
        id=new_id,
        data_type=data_type,
        stable_id=stable_id,
        official=True,
        official_updated_at=created_at,
    )
    feed.officialstatushistories = [
        Officialstatushistory(
            is_official=True,
            reviewer_email="[email protected]",
            timestamp=created_at,
            notes="Imported from JBDA as official feed.",
        )
    ]
    session.add(feed)
    session.flush()
    logger.info(
        "Created %s stable_id=%s id=%s data_type=%s",
        model_name,
        stable_id,
        new_id,
        data_type,
    )
    return feed, True


def _update_common_feed_fields(
feed: Feed, list_item: dict, detail: dict, producer_url: str
) -> None:
Expand Down
Loading