
Commit d1d68bf

Run local ruff formatting and set continue-on-error to false in test pipeline
1 parent 3d73ef6 commit d1d68bf

10 files changed: +228, -174 lines


.github/workflows/test_and_deploy.yml

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ jobs:
       - name: Check code formatting with Ruff
         run: ruff format --diff --target-version=py39
-        continue-on-error: true
+        continue-on-error: false
 
       - name: Run tests and generate HTML report
         run: |
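
With continue-on-error now set to false, a formatting diff fails the pipeline instead of merely logging a warning. The same check can be reproduced locally before pushing with ruff format --diff --target-version=py39 (ruff exits non-zero when any file would be reformatted, which is what fails the step); dropping --diff applies the fixes in place.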

dags/fetch.py

Lines changed: 43 additions & 31 deletions

@@ -6,70 +6,82 @@
 from airflow.operators.dagrun import TriggerDagRunOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
-from airflow.providers.amazon.aws.operators.lambda_function import AwsLambdaInvokeFunctionOperator
+from airflow.providers.amazon.aws.operators.lambda_function import (
+    AwsLambdaInvokeFunctionOperator,
+)
 
-lambda_function_fetch_name = os.getenv('LAMBDA_FUNCTION_FETCH_NAME')
-lambda_function_validate_name = os.getenv('LAMBDA_FUNCTION_VALIDATE_NAME')
+lambda_function_fetch_name = os.getenv("LAMBDA_FUNCTION_FETCH_NAME")
+lambda_function_validate_name = os.getenv("LAMBDA_FUNCTION_VALIDATE_NAME")
 
 default_args = {
-    'owner': 'Billy Moore',
-    'retries': 1,
-    'retry_delay': datetime.timedelta(minutes=1),
-    'depends_on_past': False,
-    'email_on_failure': False,
-    'email_on_retry': False
+    "owner": "Billy Moore",
+    "retries": 1,
+    "retry_delay": datetime.timedelta(minutes=1),
+    "depends_on_past": False,
+    "email_on_failure": False,
+    "email_on_retry": False,
 }
 
+
 # callable to check the result of lambda functions
 def check_lambda_result(task_id, **context):
-    result = context['ti'].xcom_pull(task_ids=task_id)
+    result = context["ti"].xcom_pull(task_ids=task_id)
     if result is None:
-        raise ValueError(f'Lambda function {task_id} failed to return a result.')
-    payload = result.get('Payload')
+        raise ValueError(f"Lambda function {task_id} failed to return a result.")
+    payload = result.get("Payload")
     if payload:
         response = json.loads(payload.read())
-        if response.get('StatusCode') != 200:
-            raise ValueError(f'Lambda function {task_id} failed with status code: {response.get('StatusCode')}')
+        if response.get("StatusCode") != 200:
+            raise ValueError(
+                f"Lambda function {task_id} failed with status code: {response.get('StatusCode')}"
+            )
     else:
-        raise ValueError(f'Lambda function {task_id} returned no payload.')
+        raise ValueError(f"Lambda function {task_id} returned no payload.")
+
 
 with DAG(
-    dag_id='fetch',
+    dag_id="fetch",
     default_args=default_args,
     catchup=False,
-    schedule_interval='@hourly'
+    schedule_interval="@hourly",
 ) as dag:
-
-    initiate = EmptyOperator(task_id='initiate')
+    initiate = EmptyOperator(task_id="initiate")
 
     lambda_fetch = AwsLambdaInvokeFunctionOperator(
-        task_id='lambda_fetch',
+        task_id="lambda_fetch",
         function_name=lambda_function_fetch_name,
-        payload={'filename': 'market_data_{{ ts_nodash }}.json'}
+        payload={"filename": "market_data_{{ ts_nodash }}.json"},
     )
 
     check_fetch = PythonOperator(
-        task_id='check_fetch',
+        task_id="check_fetch",
         python_callable=check_lambda_result,
-        op_kwargs={'task_id': 'lambda_fetch'}
+        op_kwargs={"task_id": "lambda_fetch"},
     )
 
     lamda_validate = AwsLambdaInvokeFunctionOperator(
-        task_id='lambda_validate',
+        task_id="lambda_validate",
         function_name=lambda_function_validate_name,
-        payload={'filename': 'market_data_{{ ts_nodash }}.json'}
+        payload={"filename": "market_data_{{ ts_nodash }}.json"},
     )
 
     check_validate = PythonOperator(
-        task_id='check_validate',
+        task_id="check_validate",
         python_callable=check_lambda_result,
-        op_kwargs={'task_id': 'lambda_validate'}
+        op_kwargs={"task_id": "lambda_validate"},
     )
 
     trigger_snowflake_ingestion = TriggerDagRunOperator(
-        task_id='trigger_snowflake_ingestion',
-        trigger_dag_id='ingest',
-        conf={'execution_timestamp': '{{ ts_nodash }}'},
+        task_id="trigger_snowflake_ingestion",
+        trigger_dag_id="ingest",
+        conf={"execution_timestamp": "{{ ts_nodash }}"},
     )
 
-    initiate >> lambda_fetch >> check_fetch >> lamda_validate >> check_validate >> trigger_snowflake_ingestion
+    (
+        initiate
+        >> lambda_fetch
+        >> check_fetch
+        >> lamda_validate
+        >> check_validate
+        >> trigger_snowflake_ingestion
+    )
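
A note on the contract behind check_lambda_result, since the reformat touches most of its lines: the XCom pushed by AwsLambdaInvokeFunctionOperator is treated as a boto3-style invoke response whose "Payload" is a file-like stream, hence payload.read() before json.loads, and the StatusCode key checked here is the custom field this repo's handlers return. A minimal sketch of that contract, with a hypothetical FakeTI stand-in (not part of the commit) and a trimmed copy of the callable:

import io
import json

def check_lambda_result(task_id, **context):
    # Trimmed copy of the DAG callable, for illustration only.
    result = context["ti"].xcom_pull(task_ids=task_id)
    if result is None:
        raise ValueError(f"Lambda function {task_id} failed to return a result.")
    payload = result.get("Payload")
    if payload is None:
        raise ValueError(f"Lambda function {task_id} returned no payload.")
    response = json.loads(payload.read())
    if response.get("StatusCode") != 200:
        raise ValueError(f"Lambda function {task_id} failed: {response}")

class FakeTI:
    """Hypothetical stand-in for Airflow's TaskInstance."""
    def xcom_pull(self, task_ids):
        # Mirrors a boto3 invoke response: "Payload" is a file-like body.
        return {"Payload": io.BytesIO(json.dumps({"StatusCode": 200}).encode())}

check_lambda_result("lambda_fetch", ti=FakeTI())  # passes; a non-200 code would raise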

dags/ingest.py

Lines changed: 45 additions & 29 deletions

@@ -5,46 +5,62 @@
 from pathlib import Path
 
 default_args = {
-    'owner': 'Billy Moore',
-    'retries': 1,
-    'retry_delay': datetime.timedelta(minutes=1),
-    'depends_on_past': False,
-    'email_on_failure': False,
-    'email_on_retry': False
+    "owner": "Billy Moore",
+    "retries": 1,
+    "retry_delay": datetime.timedelta(minutes=1),
+    "depends_on_past": False,
+    "email_on_failure": False,
+    "email_on_retry": False,
 }
 
+
 def read_sql_query(dir: str, name: str) -> str:
     root_dir = Path(__file__).parent.parent
-    sql_path = root_dir / 'sql' / dir / name
-    with open(sql_path, 'r') as f:
+    sql_path = root_dir / "sql" / dir / name
+    with open(sql_path, "r") as f:
         sql = f.read()
     return sql
 
-with DAG(
-    dag_id='ingest',
-    default_args=default_args
-):
+
+with DAG(dag_id="ingest", default_args=default_args):
     execution_timestamp = "{{ dag_run.conf['execution_timestamp'] }}"
 
-    def snowflake_task_factory(task_id: str, filename: str, timestamp_param: bool = False):
+    def snowflake_task_factory(
+        task_id: str, filename: str, timestamp_param: bool = False
+    ):
         return SnowflakeOperator(
             task_id=task_id,
-            sql=read_sql_query('loading', filename),
-            snowflake_conn_id='snowflake_predictit',
-            params={
-                'execution_timestamp': execution_timestamp
-            } if timestamp_param else {}
+            sql=read_sql_query("loading", filename),
+            snowflake_conn_id="snowflake_predictit",
+            params={"execution_timestamp": execution_timestamp}
+            if timestamp_param
+            else {},
         )
 
-    load_stage_raw = snowflake_task_factory('load_stage_raw', 'load_stage_raw.sql', timestamp_param=True)
-    load_stg_dim_markets = snowflake_task_factory('load_stg_dim_markets', 'load_stg_dim_markets.sql')
-    load_dim_markets = snowflake_task_factory('load_dim_markets', 'load_dim_markets.sql')
-    load_stg_dim_contracts = snowflake_task_factory('load_stg_dim_contracts', 'load_stg_dim_contracts.sql')
-    load_dim_contracts = snowflake_task_factory('load_dim_contracts', 'load_dim_contracts.sql')
-    load_fact_prices = snowflake_task_factory('load_fact_prices', 'load_fact_prices.sql')
-
-    load_stage_raw >> load_stg_dim_markets >> load_dim_markets >> load_stg_dim_contracts >> load_dim_contracts >> load_fact_prices
-
-
+    load_stage_raw = snowflake_task_factory(
+        "load_stage_raw", "load_stage_raw.sql", timestamp_param=True
+    )
+    load_stg_dim_markets = snowflake_task_factory(
+        "load_stg_dim_markets", "load_stg_dim_markets.sql"
+    )
+    load_dim_markets = snowflake_task_factory(
+        "load_dim_markets", "load_dim_markets.sql"
+    )
+    load_stg_dim_contracts = snowflake_task_factory(
+        "load_stg_dim_contracts", "load_stg_dim_contracts.sql"
+    )
+    load_dim_contracts = snowflake_task_factory(
+        "load_dim_contracts", "load_dim_contracts.sql"
+    )
+    load_fact_prices = snowflake_task_factory(
+        "load_fact_prices", "load_fact_prices.sql"
+    )
 
-
+    (
+        load_stage_raw
+        >> load_stg_dim_markets
+        >> load_dim_markets
+        >> load_stg_dim_contracts
+        >> load_dim_contracts
+        >> load_fact_prices
+    )
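
Ruff's reflow of the params argument in snowflake_task_factory looks different but is the same conditional expression; a quick standalone check of the equivalence (values are illustrative):

execution_timestamp = "20240101T000000"

def params_before(timestamp_param: bool = False) -> dict:
    # Pre-commit formatting of the conditional dict.
    return {
        'execution_timestamp': execution_timestamp
    } if timestamp_param else {}

def params_after(timestamp_param: bool = False) -> dict:
    # Post-commit (Ruff) formatting: the same ternary, reflowed across lines.
    return (
        {"execution_timestamp": execution_timestamp}
        if timestamp_param
        else {}
    )

assert params_before(True) == params_after(True)
assert params_before() == params_after() == {}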

lambda_fetch/lambda_function.py

Lines changed: 12 additions & 11 deletions

@@ -9,7 +9,8 @@
 logger.setLevel(logging.INFO)
 
 predictit = PredictitAPI()
-s3_client = boto3.client('s3')
+s3_client = boto3.client("s3")
+
 
 def lambda_function(filename: str) -> None:
     """
@@ -18,15 +19,15 @@ def lambda_function(filename: str) -> None:
         excution_timestamp (str): The execution timestamp for the data
     """
     # Poll the PredictIt API
-    logger.info('Polling PredictIt API market data now')
+    logger.info("Polling PredictIt API market data now")
     data = predictit.poll_market_data()
-    logger.info('Successfully polled API')
-    logger.info('Storing to S3 now')
-    bucket = os.getenv('S3_BUCKET')
+    logger.info("Successfully polled API")
+    logger.info("Storing to S3 now")
+    bucket = os.getenv("S3_BUCKET")
     if not bucket:
         raise ValueError("S3_BUCKET environment variable is not set")
     predictit.store_to_s3(data, bucket=bucket, filename=filename)
-    logging.info('Successfully stored data to S3')
+    logging.info("Successfully stored data to S3")
 
 
 def lambda_handler(event, context) -> Optional[dict]:
@@ -40,14 +41,14 @@ def lambda_handler(event, context) -> Optional[dict]:
         Dict status message
     """
     try:
-        filename = event.get('filename')
+        filename = event.get("filename")
         if not filename:
             raise ValueError("Filename must be provided in the event data")
         lambda_function(filename)
         return {
-            'StatusCode': 200,
-            'message': 'PredictAPI data succcessfully polled and stored to S3'
+            "StatusCode": 200,
+            "message": "PredictAPI data succcessfully polled and stored to S3",
         }
     except Exception as e:
-        logger.error(f'Error occurred: {e}')
-        raise
+        logger.error(f"Error occurred: {e}")
+        raise
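
The handler's contract is unchanged by the reformat: it expects a filename in the event and returns the StatusCode/message dict that the fetch DAG's check_lambda_result inspects. A local smoke test might look like the sketch below; the import path, bucket name, and filename are assumptions, and the call needs AWS credentials plus network access since it really polls the API and writes to S3:

import os

from lambda_fetch.lambda_function import lambda_handler  # import path assumed

os.environ.setdefault("S3_BUCKET", "my-predictit-bucket")  # illustrative bucket

# context is unused by this handler, so None suffices for a local run
response = lambda_handler({"filename": "market_data_smoke_test.json"}, None)
assert response["StatusCode"] == 200
print(response["message"])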

lambda_validate/lambda_function.py

Lines changed: 16 additions & 18 deletions

@@ -9,7 +9,8 @@
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-s3_client = boto3.client('s3')
+s3_client = boto3.client("s3")
+
 
 def lambda_function(execution_timestamp: str):
     """
@@ -19,32 +20,32 @@ def lambda_function(execution_timestamp: str):
         execution_timestamp (str): The execution timestamp for the data
     """
     # Validate the PredictIt API data
-    logger.info('Validating PredictIt API data now')
-    bucket = os.getenv('S3_BUCKET')
+    logger.info("Validating PredictIt API data now")
+    bucket = os.getenv("S3_BUCKET")
     if not bucket:
         raise ValueError("S3_BUCKET environment variable is not set")
-    source_key = f'predictit/stage/market_data_{execution_timestamp}.json'
-    destination_key = f'predictit/raw_data/market_data_{execution_timestamp}.json'
+    source_key = f"predictit/stage/market_data_{execution_timestamp}.json"
+    destination_key = f"predictit/raw_data/market_data_{execution_timestamp}.json"
     # Load the data from S3
     s3_object = s3_client.get_object(Bucket=bucket, Key=source_key)
-    data = json.loads(s3_object['Body'].read())
-
+    data = json.loads(s3_object["Body"].read())
+
     try:
         # Validate the data
         PredictitResponse(**data)
-        logger.info('Successfully validated data')
+        logger.info("Successfully validated data")
     except Exception as e:
-        logger.error(f'Error occurred during data validation: {e}')
+        logger.error(f"Error occurred during data validation: {e}")
         raise
 
     # copy to raw data and delete stage data
     s3_client.copy_object(
         Bucket=bucket,
-        CopySource={'Bucket': bucket, 'Key': source_key},
-        Key=destination_key
+        CopySource={"Bucket": bucket, "Key": source_key},
+        Key=destination_key,
     )
     s3_client.delete_object(Bucket=bucket, Key=source_key)
-
+
 
 def lambda_handler(event, context) -> Optional[dict]:
     """
@@ -57,14 +58,11 @@ def lambda_handler(event, context) -> Optional[dict]:
         Dict status message
     """
     try:
-        execution_timestamp = event.get('execution_timestamp')
+        execution_timestamp = event.get("execution_timestamp")
         if not execution_timestamp:
             raise ValueError("Execution timestamp must be provided in the event data")
         lambda_function(execution_timestamp)
-        return {
-            'StatusCode': 200,
-            'message': 'PredictAPI data successfully validated'
-        }
+        return {"StatusCode": 200, "message": "PredictAPI data successfully validated"}
     except Exception as e:
-        logger.error(f'Error occurred: {e}')
+        logger.error(f"Error occurred: {e}")
         raise
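
One behavior worth noting behind the reformatted copy_object call: promoting an object from predictit/stage/ to predictit/raw_data/ is copy-then-delete, since S3 has no atomic rename. A standalone sketch of the same move (bucket name and timestamp are illustrative):

import boto3

s3 = boto3.client("s3")
bucket = "my-predictit-bucket"  # illustrative
execution_timestamp = "20240101T000000"  # same shape as Airflow's {{ ts_nodash }}

source_key = f"predictit/stage/market_data_{execution_timestamp}.json"
destination_key = f"predictit/raw_data/market_data_{execution_timestamp}.json"

# Two calls rather than one: the stage copy is deleted only after the copy
# succeeds, so a failure in between leaves the object present in both prefixes.
s3.copy_object(
    Bucket=bucket,
    CopySource={"Bucket": bucket, "Key": source_key},
    Key=destination_key,
)
s3.delete_object(Bucket=bucket, Key=source_key)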

src/api.py

Lines changed: 12 additions & 10 deletions

@@ -10,11 +10,11 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-class PredictitAPI:
 
+class PredictitAPI:
     def __init__(self, base_url="https://www.predictit.org/api/marketdata"):
         self.base_url = base_url
-        self.s3_client = boto3.client('s3')
+        self.s3_client = boto3.client("s3")
 
     def poll_market_data(self, market_id: Optional[str] = None) -> Optional[dict]:
         """
@@ -41,19 +41,21 @@ def poll_market_data(self, market_id: Optional[str] = None) -> Optional[dict]:
         except Exception as e:
             logging.error(f"An error occurred: {e}")
 
-    def store_to_s3(self, data: dict, bucket: Optional[str] = None, filename: Optional[str] = None):
+    def store_to_s3(
+        self, data: dict, bucket: Optional[str] = None, filename: Optional[str] = None
+    ):
         if not filename:
-            timestamp = datetime.datetime.utcnow().strftime('%Y-%m-%dT%H-%M-%S')
-            filename = f'market_data_{timestamp}.json'
-        key = f'predictit/stage/{filename}'
+            timestamp = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H-%M-%S")
+            filename = f"market_data_{timestamp}.json"
+        key = f"predictit/stage/{filename}"
         try:
             self.s3_client.put_object(
                 Bucket=bucket,
                 Key=key,
                 Body=json.dumps(data),
-                ContentType='application/json'
+                ContentType="application/json",
             )
-            logging.info(f'Uploaded {key} to S3 bucket {bucket}')
+            logging.info(f"Uploaded {key} to S3 bucket {bucket}")
         except ClientError as e:
-            logging.error(f'Failed to upload to bucket: {e}')
-            raise
+            logging.error(f"Failed to upload to bucket: {e}")
+            raise
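
For reference, the public surface of PredictitAPI is unchanged by the reformat. A minimal usage sketch, assuming configured AWS credentials (the import path and bucket name are assumptions from the repo layout):

from src.api import PredictitAPI  # import path assumed

api = PredictitAPI()
data = api.poll_market_data()  # polls all markets when market_id is None; returns None on error
if data is not None:
    # filename defaults to a UTC-timestamped name; the key lands under predictit/stage/
    api.store_to_s3(data, bucket="my-predictit-bucket")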
