
Commit d3d42c0

Make ucx pylsp plugin configurable (#2280)
## Changes

Make the LSP linter plugin configurable with cluster information. The configuration can be provided either in a file or by a client; its provisioning is handled by the pylsp infrastructure. The Spark Connect linter is now applied only to UC Shared clusters, since Single-User clusters run in Spark Classic mode.

### Tests

- [x] manually tested
- [x] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)
1 parent 46f4239 commit d3d42c0
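For orientation, the plugin reads its settings through pylsp's configuration machinery. The sketch below shows what a client might send: the inner key names (`dataSecurityMode`, `dbrVersion`, `isServerless`, `migration_index`) and the migration-status fields come from the diffs below, while the nesting under `pylsp.plugins.pylsp_ucx` and the concrete values are illustrative assumptions:

```python
# Hypothetical client-side settings for the plugin; only the inner key names
# are taken from lsp_plugin.py below, everything else is illustrative.
plugin_settings = {
    "pylsp": {
        "plugins": {
            "pylsp_ucx": {
                "dataSecurityMode": "USER_ISOLATION",  # UC Shared cluster
                "dbrVersion": "14.3",
                "isServerless": False,
                "migration_index": [
                    {
                        "src_schema": "default",
                        "src_table": "orders",
                        "dst_catalog": "main",
                        "dst_schema": "default",
                        "dst_table": "orders",
                    },
                ],
            },
        },
    },
}
```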

15 files changed, +240 -35 lines


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ pylsp = [
 runtime = "databricks.labs.ucx.runtime:main"

 [project.entry-points.pylsp]
-plugin = "databricks.labs.ucx.source_code.lsp_plugin"
+pylsp_ucx = "databricks.labs.ucx.source_code.lsp_plugin"

 [project.urls]
 Issues = "https://github.com/databricks/ucx/issues"
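Since pylsp discovers plugins through the `pylsp` entry-point group declared here, one quick way to confirm the renamed entry point is visible is a sketch like this, assuming ucx is installed in the same environment as pylsp on Python 3.10+:

```python
# List the pylsp plugin entry points and check that 'pylsp_ucx' is among them.
from importlib.metadata import entry_points

names = [ep.name for ep in entry_points(group="pylsp")]
print(names)  # expect 'pylsp_ucx' in the list
```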

src/databricks/labs/ucx/hive_metastore/migration_status.py

Lines changed: 11 additions & 0 deletions
@@ -26,6 +26,17 @@ class MigrationStatus:
     def destination(self):
         return f"{self.dst_catalog}.{self.dst_schema}.{self.dst_table}".lower()

+    @classmethod
+    def from_json(cls, raw: dict[str, str]) -> "MigrationStatus":
+        return cls(
+            src_schema=raw['src_schema'],
+            src_table=raw['src_table'],
+            dst_catalog=raw.get('dst_catalog', None),
+            dst_schema=raw.get('dst_schema', None),
+            dst_table=raw.get('dst_table', None),
+            update_ts=raw.get('update_ts', None),
+        )
+

 @dataclass(frozen=True)
 class TableView:
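A short usage sketch for the new `from_json` helper; the payload mirrors one entry of the `migration_index` list the LSP plugin receives, with made-up table names:

```python
from databricks.labs.ucx.hive_metastore.migration_status import MigrationStatus

# Hypothetical migration-index entry: only src_schema and src_table are required,
# the dst_* and update_ts fields default to None when absent.
raw = {
    "src_schema": "default",
    "src_table": "orders",
    "dst_catalog": "main",
    "dst_schema": "default",
    "dst_table": "orders",
}
status = MigrationStatus.from_json(raw)
print(status.destination())  # main.default.orders
```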

src/databricks/labs/ucx/source_code/base.py

Lines changed: 9 additions & 1 deletion
@@ -187,11 +187,19 @@ def from_json(cls, json: dict) -> CurrentSessionState:
             catalog=json.get('catalog', DEFAULT_CATALOG),
             spark_conf=json.get('spark_conf', None),
             named_parameters=json.get('named_parameters', None),
-            data_security_mode=json.get('data_security_mode', None),
+            data_security_mode=cls.parse_security_mode(json.get('data_security_mode', None)),
             is_serverless=json.get('is_serverless', False),
             dbr_version=tuple(json['dbr_version']) if 'dbr_version' in json else None,
         )

+    @staticmethod
+    def parse_security_mode(mode_str: str | None) -> compute.DataSecurityMode | None:
+        try:
+            return compute.DataSecurityMode(mode_str) if mode_str else None
+        except ValueError:
+            logger.warning(f'Unknown data_security_mode {mode_str}')
+            return None
+

 class SequentialLinter(Linter):
     def __init__(self, linters: list[Linter]):
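The new `parse_security_mode` helper maps the raw string onto the SDK enum and degrades gracefully; a small sketch of the expected behaviour, assuming ucx and the Databricks SDK are installed:

```python
from databricks.sdk.service import compute
from databricks.labs.ucx.source_code.base import CurrentSessionState

# Known mode strings resolve to the enum member used by the linters.
assert CurrentSessionState.parse_security_mode("USER_ISOLATION") is compute.DataSecurityMode.USER_ISOLATION
# Missing or unrecognised values fall back to None (the latter logs a warning).
assert CurrentSessionState.parse_security_mode(None) is None
assert CurrentSessionState.parse_security_mode("NOT_A_MODE") is None
```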

src/databricks/labs/ucx/source_code/linters/spark_connect.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,8 @@
     PythonLinter,
     CurrentSessionState,
 )
+from databricks.sdk.service.compute import DataSecurityMode
+
 from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeHelper

@@ -238,6 +240,10 @@ def lint(self, node: NodeNG) -> Iterator[Advice]:

 class SparkConnectLinter(PythonLinter):
     def __init__(self, session_state: CurrentSessionState):
+        if session_state.data_security_mode != DataSecurityMode.USER_ISOLATION:
+            self._matchers = []
+            return
+
         self._matchers = [
             JvmAccessMatcher(session_state=session_state),
             RDDApiMatcher(session_state=session_state),
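With this gate, the Spark Connect matchers only run when the session reports a UC Shared (`USER_ISOLATION`) cluster; in other modes the linter yields nothing. A sketch of the difference, with imports assuming the ucx source tree:

```python
from databricks.sdk.service.compute import DataSecurityMode
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter

code = "spark._jspark._jvm.com.my.custom.Name()"

shared = SparkConnectLinter(CurrentSessionState(data_security_mode=DataSecurityMode.USER_ISOLATION))
single_user = SparkConnectLinter(CurrentSessionState(data_security_mode=DataSecurityMode.SINGLE_USER))

assert list(shared.lint(code))           # JVM access is flagged on UC Shared clusters
assert not list(single_user.lint(code))  # no matchers registered, so no advice
```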

src/databricks/labs/ucx/source_code/lsp_plugin.py

Lines changed: 32 additions & 4 deletions

@@ -1,16 +1,44 @@
+import logging
+from packaging import version
+
 from pylsp import hookimpl  # type: ignore
-from pylsp.workspace import Document, Workspace  # type: ignore
+from pylsp.config.config import Config  # type: ignore
+from pylsp.workspace import Document  # type: ignore
 from databricks.sdk.service.workspace import Language

+from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex, MigrationStatus
+from databricks.labs.ucx.source_code.base import CurrentSessionState
 from databricks.labs.ucx.source_code.linters.context import LinterContext
 from databricks.labs.ucx.source_code.lsp import Diagnostic


+logger = logging.getLogger(__name__)
+
+
 @hookimpl
-def pylsp_lint(workspace: Workspace, document: Document) -> list[dict]:  # pylint: disable=unused-argument
-    # TODO: initialize migration index and session state from config / env variables
-    languages = LinterContext(index=None, session_state=None)
+def pylsp_lint(config: Config, document: Document) -> list[dict]:
+    cfg = config.plugin_settings('pylsp_ucx', document_path=document.uri)
+
+    migration_index = MigrationIndex([MigrationStatus.from_json(st) for st in cfg.get('migration_index', [])])
+
+    session_state = CurrentSessionState(
+        data_security_mode=CurrentSessionState.parse_security_mode(cfg.get('dataSecurityMode', None)),
+        dbr_version=parse_dbr_version(cfg.get('dbrVersion', None)),
+        is_serverless=bool(cfg.get('isServerless', False)),
+    )
+    languages = LinterContext(index=migration_index, session_state=session_state)
     analyser = languages.linter(Language.PYTHON)
     code = document.source
     diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code)]
     return [d.as_dict() for d in diagnostics]
+
+
+def parse_dbr_version(version_str: str | None) -> tuple[int, int] | None:
+    if not version_str:
+        return None
+    try:
+        release_version = version.parse(version_str).release
+        return release_version[0], release_version[1]
+    except version.InvalidVersion:
+        logger.warning(f'Incorrect DBR version string: {version_str}')
+        return None

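The DBR version string from the client is parsed with `packaging`; a quick sketch of the expected `parse_dbr_version` behaviour, assuming ucx is installed:

```python
from databricks.labs.ucx.source_code.lsp_plugin import parse_dbr_version

assert parse_dbr_version("14.3") == (14, 3)         # (major, minor) tuple used as dbr_version
assert parse_dbr_version(None) is None              # no version supplied
assert parse_dbr_version("not a version") is None   # invalid strings log a warning and yield None
```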
tests/unit/source_code/linters/test_spark_connect.py

Lines changed: 32 additions & 19 deletions
@@ -1,12 +1,21 @@
 from itertools import chain

+import pytest
+
+
 from databricks.labs.ucx.source_code.base import Failure, CurrentSessionState
 from databricks.labs.ucx.source_code.linters.python_ast import Tree
 from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter
+from databricks.sdk.service.compute import DataSecurityMode
+
+
+@pytest.fixture
+def session_state() -> CurrentSessionState:
+    return CurrentSessionState(data_security_mode=DataSecurityMode.USER_ISOLATION)


-def test_jvm_access_match_shared():
-    linter = SparkConnectLinter(CurrentSessionState())
+def test_jvm_access_match_shared(session_state):
+    linter = SparkConnectLinter(session_state)
     code = """
 spark.range(10).collect()
 spark._jspark._jvm.com.my.custom.Name()
@@ -25,8 +34,9 @@ def test_jvm_access_match_shared():
     assert actual == expected


-def test_jvm_access_match_serverless():
-    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
+def test_jvm_access_match_serverless(session_state):
+    session_state.is_serverless = True
+    linter = SparkConnectLinter(session_state)
     code = """
 spark.range(10).collect()
 spark._jspark._jvm.com.my.custom.Name()
@@ -46,8 +56,8 @@ def test_jvm_access_match_serverless():
     assert actual == expected


-def test_rdd_context_match_shared():
-    linter = SparkConnectLinter(CurrentSessionState())
+def test_rdd_context_match_shared(session_state):
+    linter = SparkConnectLinter(session_state)
     code = """
 rdd1 = sc.parallelize([1, 2, 3])
 rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -90,8 +100,9 @@ def test_rdd_context_match_shared():
     assert actual == expected


-def test_rdd_context_match_serverless():
-    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
+def test_rdd_context_match_serverless(session_state):
+    session_state.is_serverless = True
+    linter = SparkConnectLinter(session_state)
     code = """
 rdd1 = sc.parallelize([1, 2, 3])
 rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -132,8 +143,8 @@ def test_rdd_context_match_serverless():
     ] == list(linter.lint(code))


-def test_rdd_map_partitions():
-    linter = SparkConnectLinter(CurrentSessionState())
+def test_rdd_map_partitions(session_state):
+    linter = SparkConnectLinter(session_state)
     code = """
 df = spark.createDataFrame([])
 df.rdd.mapPartitions(myUdf)
@@ -152,8 +163,8 @@ def test_rdd_map_partitions():
     assert actual == expected


-def test_conf_shared():
-    linter = SparkConnectLinter(CurrentSessionState())
+def test_conf_shared(session_state):
+    linter = SparkConnectLinter(session_state)
     code = """df.sparkContext.getConf().get('spark.my.conf')"""
     assert [
         Failure(
@@ -167,8 +178,9 @@ def test_conf_shared():
     ] == list(linter.lint(code))


-def test_conf_serverless():
-    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
+def test_conf_serverless(session_state):
+    session_state.is_serverless = True
+    linter = SparkConnectLinter(session_state)
     code = """sc._conf().get('spark.my.conf')"""
     expected = [
         Failure(
@@ -184,8 +196,8 @@ def test_conf_serverless():
     assert actual == expected


-def test_logging_shared():
-    logging_matcher = LoggingMatcher(CurrentSessionState())
+def test_logging_shared(session_state):
+    logging_matcher = LoggingMatcher(session_state)
     code = """
 sc.setLogLevel("INFO")
 setLogLevel("WARN")
@@ -225,8 +237,9 @@ def test_logging_shared():
     ] == list(chain.from_iterable([logging_matcher.lint(node) for node in Tree.parse(code).walk()]))


-def test_logging_serverless():
-    logging_matcher = LoggingMatcher(CurrentSessionState(is_serverless=True))
+def test_logging_serverless(session_state):
+    session_state.is_serverless = True
+    logging_matcher = LoggingMatcher(session_state)
     code = """
 sc.setLogLevel("INFO")
 log4jLogger = sc._jvm.org.apache.log4j
@@ -255,7 +268,7 @@ def test_logging_serverless():


 def test_valid_code():
-    linter = SparkConnectLinter(CurrentSessionState())
+    linter = SparkConnectLinter(CurrentSessionState(data_security_mode=DataSecurityMode.USER_ISOLATION))
     code = """
 df = spark.range(10)
 df.collect()

tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# ucx[session-state] {"dbr_version": [13, 3]}
+# ucx[session-state] {"dbr_version": [13, 3], "data_security_mode": "USER_ISOLATION"}
 # ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters
 spark.catalog.tableExists("table")
 # ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters

tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# ucx[session-state] {"dbr_version": [14, 3]}
+# ucx[session-state] {"dbr_version": [14, 3], "data_security_mode": "USER_ISOLATION"}
 spark.catalog.tableExists("table")
 spark.catalog.listDatabases()

tests/unit/source_code/samples/functional/spark-connect/command-context.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# ucx[session-state] {"data_security_mode": "USER_ISOLATION"}
 # ucx[to-json-in-shared-clusters:+1:6:+1:80] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
 print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson())
 dbutils.notebook.entry_point.getDbutils().notebook().getContext().toSafeJson()

tests/unit/source_code/samples/functional/spark-connect/jvm-access.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# ucx[session-state] {"data_security_mode": "USER_ISOLATION"}
 spark.range(10).collect()
 # ucx[jvm-access-in-shared-clusters:+1:0:+1:18] Cannot access Spark Driver JVM on UC Shared Clusters
 spark._jspark._jvm.com.my.custom.Name()
