Skip to content

Commit 1704002

Browse files
authored
Fix Spark version detection for automatic Deequ maven coordinate definition (#114)
* fix: revert get spark version behavior
* fix: better spark version extraction
1 parent 7ec9f6f commit 1704002

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

.github/workflows/base.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
on:
22
push:
33
branches:
4-
- "*"
4+
- "*"
55

66
jobs:
77
test:
@@ -23,7 +23,7 @@ jobs:
2323
name: Setup Java 11
2424
if: startsWith(matrix.PYSPARK_VERSION, '3')
2525
with:
26-
java-version: '11'
26+
java-version: "11"
2727

2828
- name: Running tests with pyspark==${{matrix.PYSPARK_VERSION}}
2929
env:

pydeequ/configs.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# -*- coding: utf-8 -*-
22
from functools import lru_cache
3-
import subprocess
3+
import os
44
import re
55

6+
67
SPARK_TO_DEEQU_COORD_MAPPING = {
78
"3.2": "com.amazon.deequ:deequ:2.0.1-spark-3.2",
89
"3.1": "com.amazon.deequ:deequ:2.0.0-spark-3.1",
@@ -11,17 +12,21 @@
1112
}
1213

1314

15+
def _extract_major_minor_versions(full_version: str):
16+
major_minor_pattern = re.compile(r"(\d+\.\d+)\.*")
17+
match = re.match(major_minor_pattern, full_version)
18+
if match:
19+
return match.group(1)
20+
21+
@lru_cache(maxsize=None)
def _get_spark_version() -> str:
    """Return the Spark "major.minor" version to match a Deequ coordinate.

    Reads the required ``SPARK_VERSION`` environment variable (the result is
    cached, so the variable is only consulted on the first call) and reduces
    it to its major.minor prefix via ``_extract_major_minor_versions``.

    Raises:
        RuntimeError: if ``SPARK_VERSION`` is not set in the environment.
    """
    spark_version = os.environ.get("SPARK_VERSION")
    if spark_version is None:
        # Join the keys explicitly: interpolating dict.keys() into an f-string
        # would print an unreadable "dict_keys([...])" repr to the user.
        supported = ", ".join(SPARK_TO_DEEQU_COORD_MAPPING)
        raise RuntimeError(
            f"SPARK_VERSION environment variable is required. Supported values are: {supported}"
        )
    return _extract_major_minor_versions(spark_version)
2530

2631

2732
def _get_deequ_maven_config():
@@ -30,7 +35,7 @@ def _get_deequ_maven_config():
3035
return SPARK_TO_DEEQU_COORD_MAPPING[spark_version[:3]]
3136
except KeyError:
3237
raise RuntimeError(
33-
f"Found Incompatible Spark version {spark_version}; Use one of the Supported Spark versions for Deequ: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}"
38+
f"Found incompatible Spark version {spark_version}; Use one of the Supported Spark versions for Deequ: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}"
3439
)
3540

3641

tests/test_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest

from pydeequ.configs import _extract_major_minor_versions


# Fix: pytest exposes the decorator as pytest.mark.parametrize, not
# pytest.parametrize — the latter raises AttributeError at import time,
# so this module could never be collected.
@pytest.mark.parametrize(
    "full_version, major_minor_version",
    [
        ("3.2.1", "3.2"),
        ("3.1", "3.1"),
        ("3.10.3", "3.10"),
        ("3.10", "3.10"),
    ],
)
def test_extract_major_minor_versions(full_version, major_minor_version):
    """Each full version string reduces to the expected major.minor prefix."""
    assert _extract_major_minor_versions(full_version) == major_minor_version

0 commit comments

Comments
 (0)