Commit a1a4eef: Add job to run unit tests
Parent: 75218eb

4 files changed: +53 -1 lines changed


.gitignore

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 .databricks/
 .venv/
+.pytest_cache/
 *.pyc
 __pycache__/
-.pytest_cache/
 dist/
 build/
 covid_analysis.egg-info/

.vscode/launch.json

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "type": "databricks",
+            "request": "launch",
+            "name": "Unit Tests (on Databricks)",
+            "program": "${workspaceFolder}/jobs/pytest_databricks.py",
+            "args": ["./tests", "-p", "no:cacheprovider"],
+            "env": {}
+        }
+    ]
+}
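Because the runner script below forwards its arguments straight to pytest, the "args" array in this configuration is effectively the pytest command line. As a minimal sketch (assuming pytest is installed and the working directory is the repository root), the equivalent direct call is:

import pytest

# "-p no:cacheprovider" disables pytest's cache plugin, so no .pytest_cache/
# directory is written (the same directory the .gitignore change accounts for).
exit_code = pytest.main(["./tests", "-p", "no:cacheprovider"])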

jobs/pytest_databricks.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import pytest
+import os
+import sys
+
+
+def main():
+    # Run all tests in the repository root.
+    repo_root = os.path.dirname(os.path.dirname(__file__))
+    os.chdir(repo_root)
+
+    # Skip writing pyc files on a readonly filesystem.
+    sys.dont_write_bytecode = True
+
+    _ = pytest.main(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
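One detail worth noting: pytest.main returns an exit code, which the script discards (_ = ...). If the job run should be reported as failed whenever a test fails, a small variation (a sketch, not part of this commit) would propagate that code:

import sys

import pytest

retcode = pytest.main(sys.argv[1:])
# A nonzero pytest exit code means test failures or errors; re-raise it as
# the process exit status so the job run is reported as failed.
if retcode != 0:
    raise SystemExit(int(retcode))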

tests/spark_test.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from pyspark.sql import SparkSession
+import pytest
+
+
+@pytest.fixture
+def spark() -> SparkSession:
+    """
+    Create a Spark session. Unit tests don't have access to the spark global.
+    """
+    return SparkSession.builder.getOrCreate()
+
+
+def test_spark(spark):
+    """
+    Example test that needs to run on the cluster to work.
+    """
+    data = spark.sql("select 1").collect()
+    assert data[0][0] == 1
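The same fixture can back further tests in this file. As a hypothetical example (not part of this commit), a small DataFrame round trip reusing the spark fixture above might look like:

def test_create_dataframe(spark):
    """
    Hypothetical example: build a small DataFrame and verify its contents.
    """
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
    assert df.count() == 2
    assert {row["label"] for row in df.collect()} == {"a", "b"}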
