Skip to content

Commit c62d861

Browse files
authored
Accept JSON policy object in cluster policy commands (#557)
1 parent 81c26a8 commit c62d861

File tree

3 files changed

+92
-2
lines changed

3 files changed

+92
-2
lines changed

databricks_cli/cluster_policies/api.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,34 @@
2020
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2121
# See the License for the specific language governing permissions and
2222
# limitations under the License.
23+
from json import dumps as json_dumps
24+
2325
from databricks_cli.sdk import PolicyService
2426

2527

2628
class ClusterPolicyApi(object):
2729
def __init__(self, api_client):
2830
self.client = PolicyService(api_client)
2931

32+
@staticmethod
33+
def format_policy_for_api(policy):
34+
if isinstance(policy.get("definition"), dict):
35+
policy["definition"] = json_dumps(policy["definition"])
36+
return policy
37+
3038
def create_cluster_policy(self, json):
31-
return self.client.client.perform_query('POST', '/policies/clusters/create', data=json)
39+
return self.client.client.perform_query(
40+
"POST",
41+
"/policies/clusters/create",
42+
data=ClusterPolicyApi.format_policy_for_api(json),
43+
)
3244

3345
def edit_cluster_policy(self, json):
34-
return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json)
46+
return self.client.client.perform_query(
47+
"POST",
48+
"/policies/clusters/edit",
49+
data=ClusterPolicyApi.format_policy_for_api(json),
50+
)
3551

3652
def delete_cluster_policy(self, policy_id):
3753
return self.client.delete_policy(policy_id)

tests/cluster_policies/test_api.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Databricks CLI
2+
# Copyright 2017 Databricks, Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"), except
5+
# that the use of services to which certain application programming
6+
# interfaces (each, an "API") connect requires that the user first obtain
7+
# a license for the use of the APIs from Databricks, Inc. ("Databricks"),
8+
# by creating an account at www.databricks.com and agreeing to either (a)
9+
# the Community Edition Terms of Service, (b) the Databricks Terms of
10+
# Service, or (c) another written agreement between Licensee and Databricks
11+
# for the use of the APIs.
12+
#
13+
# You may not use this file except in compliance with the License.
14+
# You may obtain a copy of the License at
15+
#
16+
# http://www.apache.org/licenses/LICENSE-2.0
17+
#
18+
# Unless required by applicable law or agreed to in writing, software
19+
# distributed under the License is distributed on an "AS IS" BASIS,
20+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
# See the License for the specific language governing permissions and
22+
# limitations under the License.
23+
24+
import mock
25+
import pytest
26+
27+
from databricks_cli.cluster_policies.api import ClusterPolicyApi
28+
29+
30+
@pytest.mark.parametrize(
31+
"policy, expected",
32+
[
33+
({"definition": "foo"}, {"definition": "foo"}),
34+
({"definition": {"foo": "bar"}}, {"definition": '{"foo": "bar"}'}),
35+
],
36+
)
37+
def test_format_policy_for_api(policy, expected):
38+
result = ClusterPolicyApi.format_policy_for_api(policy)
39+
assert result == expected
40+
41+
42+
@pytest.mark.parametrize(
43+
"fct_name, method, action",
44+
[
45+
("create_cluster_policy", "POST", "create"),
46+
("edit_cluster_policy", "POST", "edit"),
47+
],
48+
)
49+
@mock.patch(
50+
"databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api"
51+
)
52+
def test_create_and_edit_cluster_policy(
53+
mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api
54+
):
55+
mock_policy = mock.Mock()
56+
getattr(fixture_cluster_policies_api, fct_name)(mock_policy)
57+
mock_format_policy_for_api.assert_called_once_with(mock_policy)
58+
fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with(
59+
method,
60+
"/policies/clusters/{}".format(action),
61+
data=mock_format_policy_for_api.return_value,
62+
)

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
# limitations under the License.
2323
import shutil
2424
import tempfile
25+
26+
import mock
2527
import pytest
2628

2729
import databricks_cli.configure.provider as provider
30+
from databricks_cli.cluster_policies.api import ClusterPolicyApi
2831

2932

3033
@pytest.fixture(autouse=True)
@@ -33,3 +36,12 @@ def mock_conf_dir():
3336
provider._home = path
3437
yield
3538
shutil.rmtree(path)
39+
40+
41+
@pytest.fixture()
42+
def fixture_cluster_policies_api():
43+
with mock.patch(
44+
"databricks_cli.cluster_policies.api.PolicyService"
45+
) as service_mock:
46+
service_mock.return_value = mock.MagicMock()
47+
yield ClusterPolicyApi(None)

0 commit comments

Comments
 (0)