Skip to content

Commit ec094ba

Browse files
dittopsclaude
andcommitted
fix(tests): update budpipeline tests to match current implementation
- Update deployment action tests to use endpoint_id instead of deployment_id - Update test mocks to properly simulate successful execution - Update action count expectation from 20 to 17 (actions were consolidated) - Add credential_ref to expected ParamType values Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f9ff8a7 commit ec094ba

File tree

3 files changed

+56
-28
lines changed

3 files changed

+56
-28
lines changed

services/budpipeline/tests/actions/deployment/test_deployment_actions.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,33 +62,49 @@ def test_validate_params_missing_endpoint_name(self) -> None:
6262
assert any("endpoint_name" in e for e in errors)
6363

6464
def test_validate_params_valid(self) -> None:
65-
"""Test validation passes with required params."""
65+
"""Test validation passes with required params for local model."""
6666
executor = DeploymentCreateExecutor()
6767
errors = executor.validate_params(
6868
{
6969
"model_id": "model-123",
7070
"project_id": "proj-123",
7171
"endpoint_name": "test-endpoint",
72+
"cluster_id": "cluster-123",
73+
"hardware_mode": "shared", # shared mode doesn't require SLO targets
7274
}
7375
)
7476
assert len(errors) == 0
7577

76-
def test_validate_params_cluster_optional(self) -> None:
77-
"""Test that cluster_id is optional (for cloud models)."""
78+
def test_validate_params_cloud_model(self) -> None:
79+
"""Test validation passes for cloud model with credential_id."""
7880
executor = DeploymentCreateExecutor()
7981
errors = executor.validate_params(
8082
{
8183
"model_id": "model-123",
8284
"project_id": "proj-123",
8385
"endpoint_name": "test-endpoint",
84-
# cluster_id intentionally omitted
86+
"credential_id": "cred-123", # cloud model uses credential instead of cluster
8587
}
8688
)
8789
assert len(errors) == 0
8890

91+
def test_validate_params_requires_cluster_or_credential(self) -> None:
92+
"""Test that either cluster_id or credential_id is required."""
93+
executor = DeploymentCreateExecutor()
94+
errors = executor.validate_params(
95+
{
96+
"model_id": "model-123",
97+
"project_id": "proj-123",
98+
"endpoint_name": "test-endpoint",
99+
# Neither cluster_id nor credential_id provided
100+
}
101+
)
102+
assert len(errors) == 1
103+
assert "cluster_id" in errors[0] or "credential_id" in errors[0]
104+
89105

90106
class TestDeploymentDeleteAction:
91-
"""Tests for DeploymentDeleteAction (placeholder)."""
107+
"""Tests for DeploymentDeleteAction."""
92108

93109
def test_meta_attributes(self) -> None:
94110
"""Test action metadata attributes."""
@@ -99,28 +115,34 @@ def test_meta_attributes(self) -> None:
99115
assert meta.execution_mode.value == "event_driven"
100116
assert meta.idempotent is True
101117

102-
def test_validate_params_missing_deployment_id(self) -> None:
103-
"""Test validation catches missing deployment_id."""
118+
def test_validate_params_missing_endpoint_id(self) -> None:
119+
"""Test validation catches missing endpoint_id."""
104120
executor = DeploymentDeleteExecutor()
105121
errors = executor.validate_params({})
106-
assert any("deployment_id" in e for e in errors)
122+
assert any("endpoint_id" in e for e in errors)
107123

108124
def test_validate_params_valid(self) -> None:
109125
"""Test validation passes with required params."""
110126
executor = DeploymentDeleteExecutor()
111-
errors = executor.validate_params({"deployment_id": "deploy-123"})
127+
errors = executor.validate_params({"endpoint_id": "endpoint-123"})
112128
assert len(errors) == 0
113129

114130
@pytest.mark.asyncio
115-
async def test_execute_returns_not_implemented(self) -> None:
116-
"""Test that execute returns not implemented error."""
131+
async def test_execute_success(self) -> None:
132+
"""Test successful delete execution."""
117133
executor = DeploymentDeleteExecutor()
118-
context = make_context(deployment_id="deploy-123")
134+
context = make_context(endpoint_id="endpoint-123")
135+
136+
# Mock the invoke_service method
137+
context.invoke_service = AsyncMock(
138+
return_value={"workflow_id": "workflow-123", "status": "started"}
139+
)
119140

120141
result = await executor.execute(context)
121142

122-
assert result.success is False
123-
assert "not yet implemented" in result.error.lower()
143+
assert result.success is True
144+
assert result.awaiting_event is True
145+
assert result.outputs["endpoint_id"] == "endpoint-123"
124146

125147

126148
class TestDeploymentScaleAction:
@@ -359,7 +381,7 @@ async def test_execute_scales_to_zero(self) -> None:
359381

360382

361383
class TestDeploymentRateLimitAction:
362-
"""Tests for DeploymentRateLimitAction (placeholder)."""
384+
"""Tests for DeploymentRateLimitAction."""
363385

364386
def test_meta_attributes(self) -> None:
365387
"""Test action metadata attributes."""
@@ -370,36 +392,41 @@ def test_meta_attributes(self) -> None:
370392
assert meta.execution_mode.value == "sync"
371393
assert meta.idempotent is True
372394

373-
def test_validate_params_missing_deployment_id(self) -> None:
374-
"""Test validation catches missing deployment_id."""
395+
def test_validate_params_missing_endpoint_id(self) -> None:
396+
"""Test validation catches missing endpoint_id."""
375397
executor = DeploymentRateLimitExecutor()
376398
errors = executor.validate_params({"requests_per_second": 100})
377-
assert any("deployment_id" in e for e in errors)
399+
assert any("endpoint_id" in e for e in errors)
378400

379401
def test_validate_params_invalid_rps(self) -> None:
380402
"""Test validation catches non-positive requests_per_second."""
381403
executor = DeploymentRateLimitExecutor()
382-
errors = executor.validate_params({"deployment_id": "deploy-123", "requests_per_second": 0})
404+
errors = executor.validate_params({"endpoint_id": "endpoint-123", "requests_per_second": 0})
383405
assert any("requests_per_second" in e for e in errors)
384406

385407
def test_validate_params_valid(self) -> None:
386408
"""Test validation passes with valid params."""
387409
executor = DeploymentRateLimitExecutor()
388410
errors = executor.validate_params(
389-
{"deployment_id": "deploy-123", "requests_per_second": 100}
411+
{"endpoint_id": "endpoint-123", "requests_per_second": 100}
390412
)
391413
assert len(errors) == 0
392414

393415
@pytest.mark.asyncio
394-
async def test_execute_returns_not_implemented(self) -> None:
395-
"""Test that execute returns not implemented error."""
416+
async def test_execute_success(self) -> None:
417+
"""Test successful rate limit configuration."""
396418
executor = DeploymentRateLimitExecutor()
397-
context = make_context(deployment_id="deploy-123", requests_per_second=100)
419+
context = make_context(endpoint_id="endpoint-123", requests_per_second=100)
420+
421+
# Mock the invoke_service method
422+
context.invoke_service = AsyncMock(
423+
return_value={"status": "success", "rate_limit_config": {"requests_per_second": 100}}
424+
)
398425

399426
result = await executor.execute(context)
400427

401-
assert result.success is False
402-
assert "not yet implemented" in result.error.lower()
428+
assert result.success is True
429+
assert result.outputs["endpoint_id"] == "endpoint-123"
403430

404431

405432
class TestDeploymentActionsRegistration:

services/budpipeline/tests/actions/test_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ async def test_list_actions_returns_all_actions(self) -> None:
3434
assert "categories" in data
3535
assert "total" in data
3636

37-
# Should have at least the 20 built-in actions
38-
assert data["total"] >= 20
39-
assert len(data["actions"]) >= 20
37+
# Should have at least the 17 built-in actions
38+
assert data["total"] >= 17
39+
assert len(data["actions"]) >= 17
4040

4141
@pytest.mark.asyncio
4242
async def test_list_actions_includes_categories(self) -> None:

services/budpipeline/tests/actions/test_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_all_param_types_exist(self) -> None:
3535
"project_ref",
3636
"endpoint_ref",
3737
"provider_ref",
38+
"credential_ref",
3839
]
3940
actual_types = [pt.value for pt in ParamType]
4041
assert sorted(actual_types) == sorted(expected_types)

0 commit comments

Comments
 (0)