Skip to content

Commit d0edcf3

Browse files
committed
add api and processor tests
1 parent 3251235 commit d0edcf3

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

tests/test_api.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
from fastapi.testclient import TestClient
3+
4+
from make_api.app.main import app
5+
6+
client = TestClient(app)
7+
8+
9+
@pytest.fixture
10+
def sample_input():
11+
return {
12+
"passenger_count": 2,
13+
"trip_type": 1,
14+
"congestion_surcharge": 2.5,
15+
"mean_distance": 3.2,
16+
"mean_duration": 7.5,
17+
"rush_hour": 1,
18+
"vendor_id": "VTS",
19+
}
20+
21+
22+
def test_predict_one(sample_input):
23+
"""Test the predict_one endpoint works"""
24+
25+
response = client.post("/predict", json=sample_input)
26+
27+
assert response.status_code == 200
28+
29+
json_response = response.json()
30+
31+
assert "prediction" in json_response
32+
assert isinstance(json_response["prediction"], float)
33+
34+
35+
def test_predict_missing_input():
36+
"""Test the predict_one endpoint fails with missing input"""
37+
38+
incomplete_input = {
39+
"passenger_count": 2,
40+
"trip_type": 1,
41+
"congestion_surcharge": 2.5,
42+
"mean_distance": 3.2,
43+
"rush_hour": 1,
44+
"vendor_id": "VTS",
45+
} # missing mean_duration
46+
47+
response = client.post("/predict", json=incomplete_input)
48+
49+
assert response.status_code == 422

tests/test_data_processor.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from unittest.mock import MagicMock
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from make_data.data_processor import DataProcessor
7+
8+
9+
@pytest.fixture
10+
def sample_dataframe():
11+
data = {
12+
"lpep_pickup_datetime": ["2024-02-08 10:00:00", "2024-02-08 15:00:00"],
13+
"lpep_dropoff_datetime": ["2024-02-08 10:15:00", "2024-02-08 15:30:00"],
14+
"fare_amount": [10, -5],
15+
"PULocationID": [1, 2],
16+
"DOLocationID": [3, 4],
17+
"trip_distance": [2.5, 3.5],
18+
"VendorID": [1, 2],
19+
}
20+
df = pd.DataFrame(data)
21+
return df
22+
23+
24+
@pytest.fixture
25+
def mock_config():
26+
config = MagicMock()
27+
config.num_features = ["fare_amount", "trip_distance"]
28+
config.cat_features = ["vendor_id"]
29+
config.target = ["duration"]
30+
return config
31+
32+
33+
def test_process_data(sample_dataframe, mock_config):
34+
"""Test the process_data method"""
35+
36+
processor = DataProcessor(sample_dataframe, mock_config)
37+
processor.process_data()
38+
39+
assert "duration" in processor.df.columns
40+
assert "vendor_id" in processor.df.columns
41+
assert processor.df["fare_amount"].min() >= 0
42+
assert processor.df["duration"].min() >= 0
43+
44+
45+
def test_split_data(sample_dataframe, mock_config):
46+
"""Test the split_data method"""
47+
48+
processor = DataProcessor(sample_dataframe, mock_config)
49+
processor.process_data()
50+
train, test = processor.split_data()
51+
52+
assert len(train) > 0
53+
assert len(test) > 0
54+
assert len(train) + len(test) == len(processor.df)

0 commit comments

Comments
 (0)