Skip to content

Commit 10f7e8c

Browse files
committed
test refactor
Signed-off-by: Pat O'Connor <[email protected]>
1 parent 0b66a1b commit 10f7e8c

File tree

4 files changed

+107
-107
lines changed

4 files changed

+107
-107
lines changed
Lines changed: 104 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
1+
# Copyright 2024 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import pytest
216
import yaml
317
from pathlib import Path
4-
from unittest.mock import MagicMock, patch
18+
from unittest.mock import MagicMock
519
from codeflare_sdk.ray.job.job import RayJob, RayJobSpec, RayJobStatus
620

7-
# Path to expected YAML files
821
parent = Path(__file__).resolve().parents[4] # project directory
922
expected_yamls_dir = f"{parent}/tests/test_cluster_yamls/ray"
1023

1124

1225
def test_rayjob_to_dict_minimal():
1326
"""Test RayJob.to_dict() with minimal configuration using YAML comparison"""
1427
spec = RayJobSpec(entrypoint="python test.py")
15-
job = RayJob(
16-
metadata={"name": "test-job", "namespace": "default"},
17-
spec=spec
18-
)
19-
28+
job = RayJob(metadata={"name": "test-job", "namespace": "default"}, spec=spec)
29+
2030
result = job.to_dict()
21-
31+
2232
# Load expected YAML
2333
with open(f"{expected_yamls_dir}/rayjob-minimal.yaml") as f:
2434
expected = yaml.safe_load(f)
25-
35+
2636
assert result == expected
2737

2838

@@ -43,53 +53,53 @@ def test_rayjob_to_dict_full_spec():
4353
message="Job is running successfully",
4454
start_time="2024-01-01T10:00:00Z",
4555
end_time=None,
46-
driver_info={"id": "driver-123", "node_ip_address": "10.0.0.1", "pid": "12345"}
56+
driver_info={"id": "driver-123", "node_ip_address": "10.0.0.1", "pid": "12345"},
4757
)
48-
58+
4959
job = RayJob(
5060
metadata={
5161
"name": "ml-training-job",
5262
"namespace": "ml-jobs",
5363
"labels": {"app": "ml-training", "version": "v1"},
54-
"annotations": {"description": "Machine learning training job"}
64+
"annotations": {"description": "Machine learning training job"},
5565
},
56-
spec=spec
66+
spec=spec,
5767
)
58-
68+
5969
result = job.to_dict()
60-
70+
6171
# Load expected YAML
6272
with open(f"{expected_yamls_dir}/rayjob-full-spec.yaml") as f:
6373
expected = yaml.safe_load(f)
64-
74+
6575
assert result == expected
6676

6777

6878
def test_rayjob_to_dict_with_existing_status():
6979
"""Test RayJob.to_dict() when status is already set using YAML comparison"""
7080
spec = RayJobSpec(entrypoint="python test.py")
71-
81+
7282
# Pre-existing status from Kubernetes
7383
existing_status = {
7484
"jobStatus": "SUCCEEDED",
7585
"jobId": "raysubmit_12345",
7686
"message": "Job completed successfully",
7787
"startTime": "2024-01-01T10:00:00Z",
78-
"endTime": "2024-01-01T10:05:00Z"
88+
"endTime": "2024-01-01T10:05:00Z",
7989
}
80-
90+
8191
job = RayJob(
8292
metadata={"name": "completed-job", "namespace": "default"},
8393
spec=spec,
84-
status=existing_status
94+
status=existing_status,
8595
)
86-
96+
8797
result = job.to_dict()
88-
98+
8999
# Load expected YAML
90100
with open(f"{expected_yamls_dir}/rayjob-existing-status.yaml") as f:
91101
expected = yaml.safe_load(f)
92-
102+
93103
assert result == expected
94104

95105

@@ -102,129 +112,119 @@ def test_rayjob_status_enum_values():
102112
assert RayJobStatus.FAILED == "FAILED"
103113

104114

105-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.config_check')
106-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.get_api_client')
107-
@patch('kubernetes.dynamic.DynamicClient')
108-
def test_rayjob_apply_success(mock_dynamic_client, mock_get_api_client, mock_config_check):
115+
def test_rayjob_apply_success(mocker):
109116
"""Test RayJob.apply() method successful execution"""
110-
# Mock the Kubernetes API components
117+
# Mock Kubernetes configuration and API components
118+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.config_check")
119+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.get_api_client")
120+
mock_dynamic_client = mocker.patch("kubernetes.dynamic.DynamicClient")
121+
111122
mock_api_instance = MagicMock()
112123
mock_dynamic_client.return_value.resources.get.return_value = mock_api_instance
113-
124+
114125
spec = RayJobSpec(entrypoint="python test.py")
115-
job = RayJob(
116-
metadata={"name": "test-job", "namespace": "test-ns"},
117-
spec=spec
118-
)
119-
126+
job = RayJob(metadata={"name": "test-job", "namespace": "test-ns"}, spec=spec)
127+
120128
# Test successful apply
121129
job.apply()
122-
130+
123131
# Verify the API was called correctly
124-
mock_config_check.assert_called_once()
125-
mock_get_api_client.assert_called_once()
126-
mock_dynamic_client.assert_called_once()
127132
mock_api_instance.server_side_apply.assert_called_once()
128-
133+
129134
# Check the server_side_apply call arguments
130135
call_args = mock_api_instance.server_side_apply.call_args
131-
assert call_args[1]['field_manager'] == 'codeflare-sdk'
132-
assert call_args[1]['group'] == 'ray.io'
133-
assert call_args[1]['version'] == 'v1'
134-
assert call_args[1]['namespace'] == 'test-ns'
135-
assert call_args[1]['plural'] == 'rayjobs'
136-
assert call_args[1]['force_conflicts'] == False
137-
136+
assert call_args[1]["field_manager"] == "codeflare-sdk"
137+
assert call_args[1]["group"] == "ray.io"
138+
assert call_args[1]["version"] == "v1"
139+
assert call_args[1]["namespace"] == "test-ns"
140+
assert call_args[1]["plural"] == "rayjobs"
141+
assert call_args[1]["force_conflicts"] == False
142+
138143
# Verify the body contains the expected job structure
139-
body = call_args[1]['body']
140-
assert body['apiVersion'] == 'ray.io/v1'
141-
assert body['kind'] == 'RayJob'
142-
assert body['metadata']['name'] == 'test-job'
143-
assert body['spec']['entrypoint'] == 'python test.py'
144+
body = call_args[1]["body"]
145+
assert body["apiVersion"] == "ray.io/v1"
146+
assert body["kind"] == "RayJob"
147+
assert body["metadata"]["name"] == "test-job"
148+
assert body["spec"]["entrypoint"] == "python test.py"
144149

145150

146-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.config_check')
147-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.get_api_client')
148-
@patch('kubernetes.dynamic.DynamicClient')
149-
def test_rayjob_apply_with_force(mock_dynamic_client, mock_get_api_client, mock_config_check):
151+
def test_rayjob_apply_with_force(mocker):
150152
"""Test RayJob.apply() method with force=True"""
153+
# Mock Kubernetes configuration and API components
154+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.config_check")
155+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.get_api_client")
156+
mock_dynamic_client = mocker.patch("kubernetes.dynamic.DynamicClient")
157+
151158
mock_api_instance = MagicMock()
152159
mock_dynamic_client.return_value.resources.get.return_value = mock_api_instance
153-
160+
154161
spec = RayJobSpec(entrypoint="python test.py")
155-
job = RayJob(
156-
metadata={"name": "test-job", "namespace": "default"},
157-
spec=spec
158-
)
159-
162+
job = RayJob(metadata={"name": "test-job", "namespace": "default"}, spec=spec)
163+
160164
# Test apply with force=True
161165
job.apply(force=True)
162-
166+
163167
# Verify force_conflicts was set to True
164168
call_args = mock_api_instance.server_side_apply.call_args
165-
assert call_args[1]['force_conflicts'] == True
169+
assert call_args[1]["force_conflicts"] == True
166170

167171

168-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.config_check')
169-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.get_api_client')
170-
@patch('kubernetes.dynamic.DynamicClient')
171-
def test_rayjob_apply_dynamic_client_error(mock_dynamic_client, mock_get_api_client, mock_config_check):
172+
def test_rayjob_apply_dynamic_client_error(mocker):
172173
"""Test RayJob.apply() method with DynamicClient initialization error"""
174+
# Mock Kubernetes configuration
175+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.config_check")
176+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.get_api_client")
177+
173178
# Mock DynamicClient to raise AttributeError
179+
mock_dynamic_client = mocker.patch("kubernetes.dynamic.DynamicClient")
174180
mock_dynamic_client.side_effect = AttributeError("Failed to connect")
175-
181+
176182
spec = RayJobSpec(entrypoint="python test.py")
177-
job = RayJob(
178-
metadata={"name": "test-job", "namespace": "default"},
179-
spec=spec
180-
)
181-
183+
job = RayJob(metadata={"name": "test-job", "namespace": "default"}, spec=spec)
184+
182185
# Test that RuntimeError is raised
183186
with pytest.raises(RuntimeError, match="Failed to initialize DynamicClient"):
184187
job.apply()
185188

186189

187-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.config_check')
188-
@patch('codeflare_sdk.common.kubernetes_cluster.auth.get_api_client')
189-
@patch('kubernetes.dynamic.DynamicClient')
190-
@patch('codeflare_sdk.common._kube_api_error_handling')
191-
def test_rayjob_apply_api_error(mock_error_handling, mock_dynamic_client, mock_get_api_client, mock_config_check):
190+
def test_rayjob_apply_api_error(mocker):
192191
"""Test RayJob.apply() method with Kubernetes API error"""
192+
# Mock Kubernetes configuration and API components
193+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.config_check")
194+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.get_api_client")
195+
mock_dynamic_client = mocker.patch("kubernetes.dynamic.DynamicClient")
196+
mock_error_handling = mocker.patch("codeflare_sdk.common._kube_api_error_handling")
197+
193198
# Mock the API to raise an exception
194199
mock_api_instance = MagicMock()
195200
mock_api_instance.server_side_apply.side_effect = Exception("API Error")
196201
mock_dynamic_client.return_value.resources.get.return_value = mock_api_instance
197-
202+
198203
spec = RayJobSpec(entrypoint="python test.py")
199-
job = RayJob(
200-
metadata={"name": "test-job", "namespace": "default"},
201-
spec=spec
202-
)
203-
204+
job = RayJob(metadata={"name": "test-job", "namespace": "default"}, spec=spec)
205+
204206
# Test that error handling is called
205207
job.apply()
206-
208+
207209
# Verify error handling was called
208210
mock_error_handling.assert_called_once()
209211

210212

211-
def test_rayjob_default_namespace_in_apply():
213+
def test_rayjob_default_namespace_in_apply(mocker):
212214
"""Test that RayJob.apply() uses 'default' namespace when not specified in metadata"""
213-
with patch('codeflare_sdk.common.kubernetes_cluster.auth.config_check'), \
214-
patch('codeflare_sdk.common.kubernetes_cluster.auth.get_api_client'), \
215-
patch('kubernetes.dynamic.DynamicClient') as mock_dynamic_client:
216-
217-
mock_api_instance = MagicMock()
218-
mock_dynamic_client.return_value.resources.get.return_value = mock_api_instance
219-
220-
spec = RayJobSpec(entrypoint="python test.py")
221-
job = RayJob(
222-
metadata={"name": "test-job"}, # No namespace specified
223-
spec=spec
224-
)
225-
226-
job.apply()
227-
228-
# Verify default namespace was used
229-
call_args = mock_api_instance.server_side_apply.call_args
230-
assert call_args[1]['namespace'] == 'default'
215+
# Mock Kubernetes configuration and API components
216+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.config_check")
217+
mocker.patch("codeflare_sdk.common.kubernetes_cluster.auth.get_api_client")
218+
mock_dynamic_client = mocker.patch("kubernetes.dynamic.DynamicClient")
219+
220+
mock_api_instance = MagicMock()
221+
mock_dynamic_client.return_value.resources.get.return_value = mock_api_instance
222+
223+
spec = RayJobSpec(entrypoint="python test.py")
224+
job = RayJob(metadata={"name": "test-job"}, spec=spec) # No namespace specified
225+
226+
job.apply()
227+
228+
# Verify default namespace was used
229+
call_args = mock_api_instance.server_side_apply.call_args
230+
assert call_args[1]["namespace"] == "default"

tests/test_cluster_yamls/ray/rayjob-existing-status.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ status:
1919
jobId: raysubmit_12345
2020
message: Job completed successfully
2121
startTime: '2024-01-01T10:00:00Z'
22-
endTime: '2024-01-01T10:05:00Z'
22+
endTime: '2024-01-01T10:05:00Z'

tests/test_cluster_yamls/ray/rayjob-full-spec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ status:
3434
driver_info:
3535
id: driver-123
3636
node_ip_address: 10.0.0.1
37-
pid: '12345'
37+
pid: '12345'

tests/test_cluster_yamls/ray/rayjob-minimal.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ status:
1919
message: null
2020
start_time: null
2121
end_time: null
22-
driver_info: null
22+
driver_info: null

0 commit comments

Comments
 (0)