Skip to content

Commit 66c6332

Browse files
authored
Merge pull request #3 from datakind/sftp-integration
fix: linting errors
2 parents 3d91f3b + 1033581 commit 66c6332

File tree

7 files changed

+570
-52
lines changed

7 files changed

+570
-52
lines changed

.github/workflows/type-check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
python-version: "3.10"
1616
- name: Get changed files
1717
id: changed-files
18-
uses: step-security/changed-files@45
18+
uses: step-security/changed-files@v45
1919
with:
2020
files: |
2121
src/**/*.py

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ dependencies = [
2121
"strenum>=0.4.15",
2222
"tomli~=2.0; python_version<'3.11'",
2323
"jsonpickle>=4.0.1",
24+
"requests>=2.0.0",
25+
"types-requests",
26+
"types-paramiko",
27+
"pandas"
2428
]
2529

2630
[project.urls]

src/worker/authn.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,28 @@
22

33
from fastapi.security import (
44
OAuth2PasswordBearer,
5-
OAuth2PasswordRequestForm,
65
APIKeyHeader,
7-
APIKeyQuery,
86
)
97
from pydantic import BaseModel
108
from datetime import timedelta, datetime, timezone
119
from .config import env_vars
1210
from typing import Annotated
13-
from fastapi import Depends, HTTPException, status
11+
from fastapi import Depends, HTTPException, status, Security
1412
from jwt.exceptions import InvalidTokenError
1513

1614
oauth2_scheme = OAuth2PasswordBearer(
1715
tokenUrl="token",
1816
)
1917

18+
api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False)
19+
api_key_inst_header = APIKeyHeader(
20+
name="INST", scheme_name="api-inst", auto_error=False
21+
)
22+
# The following is for use by the frontend enduser only.
23+
api_key_enduser_header = APIKeyHeader(
24+
name="ENDUSER", scheme_name="api-enduser", auto_error=False
25+
)
26+
2027

2128
class Token(BaseModel):
2229
access_token: str
@@ -27,7 +34,30 @@ class TokenData(BaseModel):
2734
username: str | None = None
2835

2936

30-
def check_creds(username: str, password: str):
37+
def get_api_key(
38+
api_key_header: str = Security(api_key_header),
39+
api_key_inst_header: str = Security(api_key_inst_header),
40+
api_key_enduser_header: str = Security(api_key_enduser_header),
41+
) -> tuple:
42+
"""Retrieve the api key and enduser header key if present.
43+
44+
Args:
45+
api_key_header: The API key passed in the HTTP header.
46+
47+
Returns:
48+
A tuple with the api key and enduser header if present. Authentication happens elsewhere.
49+
Raises:
50+
HTTPException: If the API key is invalid or missing.
51+
"""
52+
if api_key_header:
53+
return (api_key_header, api_key_inst_header, api_key_enduser_header)
54+
raise HTTPException(
55+
status_code=status.HTTP_401_UNAUTHORIZED,
56+
detail="Invalid or missing API Key",
57+
)
58+
59+
60+
def check_creds(username: str, password: str) -> bool:
3161
if username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]:
3262
return True
3363
raise HTTPException(
@@ -36,13 +66,13 @@ def check_creds(username: str, password: str):
3666
)
3767

3868

39-
def create_access_token(data: dict, expires_delta: timedelta | None = None):
69+
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
4070
to_encode = data.copy()
4171
if expires_delta:
4272
expire = datetime.now(timezone.utc) + expires_delta
4373
else:
4474
expire = datetime.now(timezone.utc) + timedelta(
45-
minutes=env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]
75+
minutes=float(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
4676
)
4777
to_encode.update({"exp": expire})
4878
encoded_jwt = jwt.encode(

src/worker/main.py

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,27 @@
22

33
import logging
44
from typing import Any, Annotated
5-
from fastapi import FastAPI, Depends, HTTPException, status
5+
from fastapi import FastAPI, Depends, HTTPException, status, Security
66
from fastapi.responses import FileResponse
77

88
from pydantic import BaseModel
9-
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
9+
from fastapi.security import OAuth2PasswordRequestForm
1010
from .utilities import (
1111
get_sftp_bucket_name,
1212
StorageControl,
13+
split_csv_and_generate_signed_urls,
14+
fetch_institution_ids,
1315
)
1416
from .config import sftp_vars, env_vars, startup_env_vars
15-
from .authn import Token, get_current_username, check_creds, create_access_token
16-
from datetime import timedelta, datetime, timezone
17+
from .authn import (
18+
Token,
19+
get_current_username,
20+
check_creds,
21+
create_access_token,
22+
get_api_key,
23+
)
24+
from datetime import timedelta
25+
import os
1726

1827
# Set the logging
1928
logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s")
@@ -41,8 +50,9 @@ class PdpPullRequest(BaseModel):
4150
class PdpPullResponse(BaseModel):
4251
"""Fields for the PDP pull response."""
4352

44-
pdp_inst_generated: list[int]
45-
pdp_inst_not_found: list[int]
53+
sftp_files: list[dict]
54+
pdp_inst_generated: list[str]
55+
pdp_inst_not_found: list[str]
4656

4757

4858
@app.on_event("startup")
@@ -84,30 +94,86 @@ async def login_for_access_token(
8494
return Token(access_token=access_token, token_type="bearer")
8595

8696

87-
def sftp_helper(
88-
storage_control: StorageControl, sftp_source_filename: str, dest_filename: str
89-
):
90-
storage_control.copy_from_sftp_to_gcs(
91-
sftp_vars["SFTP_HOST"],
92-
sftp_vars["SFTP_PORT"],
93-
sftp_vars["SFTP_USER"],
94-
sftp_vars["SFTP_PASSWORD"],
95-
sftp_source_filename,
96-
get_sftp_bucket_name(env_vars["ENV"]),
97-
dest_filename,
98-
)
97+
def sftp_helper(storage_control: StorageControl, sftp_source_filenames: list) -> list:
98+
"""
99+
For each source file in sftp_source_filenames, copies the file from the SFTP
100+
server to GCS. The destination filename is automatically generated by prefixing
101+
the base name of the source file with "processed_".
102+
103+
Args:
104+
storage_control (StorageControl): An instance with a method `copy_from_sftp_to_gcs`.
105+
sftp_source_filenames (list): A list of file paths on the SFTP server.
106+
"""
107+
num_files = len(sftp_source_filenames)
108+
logger.info(f"Starting sftp_helper for {num_files} file(s).")
109+
all_blobs = []
110+
for sftp_source_filename in sftp_source_filenames:
111+
sftp_source_filename = sftp_source_filename["path"]
112+
if (
113+
sftp_source_filename
114+
== "./receive/AO1600pdp_AO1600_AR_DEIDENTIFIED_STUDYID_20250228030226.csv"
115+
):
116+
logger.debug(f"Processing source file: {sftp_source_filename}")
117+
118+
# Extract the base filename.
119+
base_filename = os.path.basename(sftp_source_filename)
120+
dest_filename = f"{base_filename}"
121+
logger.debug(f"Destination filename will be: {dest_filename}")
122+
123+
try:
124+
storage_control.copy_from_sftp_to_gcs(
125+
sftp_vars["SFTP_HOST"],
126+
22,
127+
sftp_vars["SFTP_USER"],
128+
sftp_vars["SFTP_PASSWORD"],
129+
sftp_source_filename,
130+
get_sftp_bucket_name(env_vars["ENV"]),
131+
dest_filename,
132+
)
133+
all_blobs.append(dest_filename)
134+
logger.info(
135+
f"Successfully processed '{sftp_source_filename}' as '{dest_filename}'."
136+
)
137+
return all_blobs
138+
except Exception as e:
139+
logger.error(
140+
f"Error processing '{sftp_source_filename}': {e}", exc_info=True
141+
)
142+
return all_blobs
99143

100144

101145
@app.post("/execute-pdp-pull", response_model=PdpPullResponse)
102146
def execute_pdp_pull(
103147
req: PdpPullRequest,
104148
current_username: Annotated[str, Depends(get_current_username)],
105149
storage_control: Annotated[StorageControl, Depends(StorageControl)],
150+
api_key_enduser_tuple: str = Security(get_api_key),
106151
) -> Any:
107152
"""Performs the PDP pull of the file."""
108153
storage_control.create_bucket_if_not_exists(get_sftp_bucket_name(env_vars["ENV"]))
109-
sftp_helper(storage_control, "sftp_file.csv", "write_out_file.csv")
154+
files = storage_control.list_sftp_files(
155+
sftp_vars["SFTP_HOST"], 22, sftp_vars["SFTP_USER"], sftp_vars["SFTP_PASSWORD"]
156+
)
157+
all_blobs = sftp_helper(storage_control, files)
158+
valid_pdp_ids = []
159+
invalid_ids = []
160+
161+
for blobs in all_blobs:
162+
signed_urls = split_csv_and_generate_signed_urls(
163+
bucket_name=get_sftp_bucket_name(env_vars["ENV"]), source_blob_name=blobs
164+
)
165+
166+
temp_valid_pdp_ids, temp_invalid_ids = fetch_institution_ids(
167+
pdp_ids=list(signed_urls.keys()),
168+
backend_api_key=next(
169+
key for key in api_key_enduser_tuple if key is not None
170+
),
171+
)
172+
valid_pdp_ids.append(temp_valid_pdp_ids)
173+
invalid_ids.append(temp_invalid_ids)
174+
110175
return {
111-
"pdp_inst_generated": [],
112-
"pdp_inst_not_found": [],
176+
"sftp_files": files,
177+
"pdp_inst_generated": valid_pdp_ids,
178+
"pdp_inst_not_found": invalid_ids,
113179
}

src/worker/main_test.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""Test file for the main.py file and constituent API functions."""
22

33
import pytest
4-
import json
4+
from typing import Any
55

66
from fastapi.testclient import TestClient
77
from .main import app
8-
import uuid
9-
from .authn import get_current_username
8+
from .authn import get_current_username, get_api_key
109
from unittest import mock
1110
from .utilities import StorageControl
1211

@@ -21,22 +20,26 @@ def get_current_username_override():
2120
def storage_control_override():
2221
return MOCK_STORAGE
2322

23+
def get_api_key_override():
24+
return ("valid_api_key", "end_user")
25+
2426
app.dependency_overrides[StorageControl] = storage_control_override
2527

2628
app.dependency_overrides[get_current_username] = get_current_username_override
2729

30+
app.dependency_overrides[get_api_key] = get_api_key_override
2831
client = TestClient(app, root_path="/workers/api/v1")
2932
yield client
3033
app.dependency_overrides.clear()
3134

3235

33-
def test_get_root(client: TestClient):
36+
def test_get_root(client: TestClient) -> Any:
3437
"""Test GET /."""
3538
response = client.get("/")
3639
assert response.status_code == 200
3740

3841

39-
def test_retrieve_token(client: TestClient):
42+
def test_retrieve_token(client: TestClient) -> Any:
4043
"""Test POST /token."""
4144
response = client.post(
4245
"/token",
@@ -46,14 +49,23 @@ def test_retrieve_token(client: TestClient):
4649
assert response.status_code == 200
4750

4851

49-
def test_execute_pdp_pull(client: TestClient):
52+
def test_execute_pdp_pull(client: TestClient) -> Any:
5053
"""Test POST /execute-pdp-pull."""
51-
MOCK_STORAGE.copy_from_sftp_to_gcs.return_value = None
54+
MOCK_STORAGE.copy_from_sftp_to_gcs.side_effect = (
55+
lambda filename: f"processed_{filename}"
56+
)
5257
MOCK_STORAGE.create_bucket_if_not_exists.return_value = None
58+
MOCK_STORAGE.list_sftp_files.return_value = [
59+
{"path": "file1.csv"},
60+
{"path": "file2.csv"},
61+
]
5362

5463
response = client.post("/execute-pdp-pull", json={"placeholder": "val"})
64+
65+
# Verify the response status and content.
5566
assert response.status_code == 200
5667
assert response.json() == {
68+
"sftp_files": [{"path": "file1.csv"}, {"path": "file2.csv"}],
5769
"pdp_inst_generated": [],
5870
"pdp_inst_not_found": [],
5971
}

0 commit comments

Comments
 (0)