Skip to content

Commit 2a23b77

Browse files
committed
[chore] fix ty on tests/cascade
1 parent 98a6796 commit 2a23b77

File tree

5 files changed

+29
-22
lines changed

5 files changed

+29
-22
lines changed

tests/cascade/benchmarks/image_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99
import io
1010

11-
import matplotlib.pyplot as plt
11+
import matplotlib.pyplot as plt# ty: ignore[unresolved-import]
1212
import numpy as np
1313
from PIL import Image
1414

15-
from cascade.cascade import Cascade
16-
from cascade.fluent import Fluent, Payload # type: ignore
17-
from cascade.visualise import visualise
15+
from cascade.cascade import Cascade # ty: ignore[unresolved-import]
16+
from cascade.fluent import Fluent, Payload # ty: ignore[unresolved-import]
17+
from cascade.visualise import visualise # ty: ignore[unresolved-import]
1818

1919

2020
def mandelbrot(c, max_iter):

tests/cascade/executor/test_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_func(x: np.ndarray) -> np.ndarray:
135135
callback(
136136
m1,
137137
TaskSequence(
138-
worker=w0, tasks=["source", "sink"], publish={sink_o}, extra_env={}
138+
worker=w0, tasks=["source", "sink"], publish={sink_o}, extra_env=[]
139139
),
140140
)
141141
# NOTE we need to expect source_o dataset too, because of no finegraining for host-wide and worker-only
@@ -191,7 +191,7 @@ def test_func(x: np.ndarray) -> np.ndarray:
191191
logger.debug(f"about to remove received message {m}")
192192
expected.remove(m)
193193
callback(
194-
m1, TaskSequence(worker=w0, tasks=["sink"], publish={sink_o}, extra_env={})
194+
m1, TaskSequence(worker=w0, tasks=["sink"], publish={sink_o}, extra_env=[])
195195
)
196196
expected = [
197197
DatasetPublished(w0, ds=sink_o, transmit_idx=None),

tests/cascade/executor/test_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _send_command(
5656
raise ValueError(f"double allocate on {key}")
5757
opened_buffers.add(key)
5858
purging_tracker.add(key)
59-
return shm_api.AllocateResponse(shmid=key, error=None)
59+
return shm_api.AllocateResponse(shmid=key, error="")
6060
elif isinstance(comm, shm_api.CloseCallback):
6161
opened_buffers.remove(key2shmid(comm.key))
6262
else:
@@ -79,7 +79,7 @@ def _send_command(
7979
worker=worker,
8080
tasks=[],
8181
publish=set(),
82-
extra_env={},
82+
extra_env=[],
8383
)
8484
emptyRc = entrypoint.RunnerContext(
8585
workerId=worker,
@@ -115,7 +115,7 @@ def test_func(x):
115115
worker=worker,
116116
tasks=["t2"],
117117
publish={t2ds},
118-
extra_env={},
118+
extra_env=[],
119119
)
120120
oneTaskJob = JobInstance(tasks={"t2": t2}, edges=[])
121121
oneTaskRc = entrypoint.RunnerContext(
@@ -153,7 +153,7 @@ def test_func(x):
153153
worker=worker,
154154
tasks=["t3a", "t3b"],
155155
publish={t3o},
156-
extra_env={},
156+
extra_env=[],
157157
)
158158
twoTaskJob = JobInstance(
159159
tasks={"t3a": t3a, "t3b": t3b},
@@ -216,7 +216,7 @@ def gen_func():
216216
worker=worker,
217217
tasks=["t4g"] + [f"t4c{i}" for i in range(N)],
218218
publish=set(t4pOutputs),
219-
extra_env={},
219+
extra_env=[],
220220
)
221221
t4Job = JobInstance(
222222
tasks={**{"t4g": t4g}, **{f"t4c{i}": t4c for i in range(N)}},

tests/cascade/gateway/test_run.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_job():
8888

8989
submit_job_req = api.SubmitJobRequest(job=js)
9090
submit_job_res = client.request_response(submit_job_req, url)
91+
assert isinstance(submit_job_res, api.SubmitJobResponse)
9192
job_id = submit_job_res.job_id
9293
assert submit_job_res.error is None
9394
assert job_id is not None
@@ -96,11 +97,10 @@ def test_job():
9697
job_progress_req = api.JobProgressRequest(job_ids=[job_id])
9798
while tries < tries_limit:
9899
job_progress_res = client.request_response(job_progress_req, url)
100+
assert isinstance(job_progress_res, api.JobProgressResponse)
99101
assert job_progress_res.error is None
100-
is_computed = job_progress_res.progresses[job_id].pct == "100.00"
101-
is_datasets = (
102-
ji.jobInstance.ext_outputs[0] in job_progress_res.datasets[job_id]
103-
)
102+
is_computed = job_progress_res.progresses[job_id].pct == "100.00" # ty: ignore[possibly-missing-attribute]
103+
is_datasets = ji.jobInstance.ext_outputs[0] in job_progress_res.datasets[job_id]
104104
if is_computed and is_datasets:
105105
break
106106
else:
@@ -114,15 +114,17 @@ def test_job():
114114
job_id=job_id, dataset_id=ji.jobInstance.ext_outputs[0]
115115
)
116116
result_retrieval_res = client.request_response(result_retrieval_req, url)
117+
assert isinstance(result_retrieval_res, api.ResultRetrievalResponse)
117118
assert result_retrieval_res.error is None
118119
assert result_retrieval_res.result is not None
119-
deser = api.decoded_result(result_retrieval_res, ji)
120+
deser = api.decoded_result(result_retrieval_res, ji.jobInstance)
120121
assert deser == job_func(init_value)
121122

122123
result_deletion_req = api.ResultDeletionRequest(
123124
datasets={job_id: [ji.jobInstance.ext_outputs[0]]}
124125
)
125126
result_deletion_res = client.request_response(result_deletion_req, url)
127+
assert isinstance(result_deletion_res, api.ResultDeletionResponse)
126128
assert result_deletion_res.error is None
127129

128130
# fail job
@@ -137,6 +139,7 @@ def test_job():
137139

138140
submit_job_req = api.SubmitJobRequest(job=js)
139141
submit_job_res = client.request_response(submit_job_req, url)
142+
assert isinstance(submit_job_res, api.SubmitJobResponse)
140143
job_id = submit_job_res.job_id
141144
assert submit_job_res.error is None
142145
assert job_id is not None
@@ -145,9 +148,10 @@ def test_job():
145148
job_progress_req = api.JobProgressRequest(job_ids=[job_id])
146149
while tries < tries_limit:
147150
job_progress_res = client.request_response(job_progress_req, url)
151+
assert isinstance(job_progress_res, api.JobProgressResponse)
148152
assert job_progress_res.error is None
149-
assert job_progress_res.progresses[job_id].pct != "100.00"
150-
if job_progress_res.progresses[job_id].failure is not None:
153+
assert job_progress_res.progresses[job_id].pct != "100.00" # ty: ignore[possibly-missing-attribute]
154+
if job_progress_res.progresses[job_id].failure is not None: # ty: ignore[possibly-missing-attribute]
151155
break
152156
else:
153157
tries += 1
@@ -169,6 +173,8 @@ def test_job():
169173
req = api.SubmitJobRequest(job=js)
170174
res1 = client.request_response(req, url)
171175
res2 = client.request_response(req, url)
176+
assert isinstance(res1, api.SubmitJobResponse)
177+
assert isinstance(res2, api.SubmitJobResponse)
172178
assert res1.error is None
173179
assert res2.error is None
174180
assert res1.job_id is not None
@@ -179,9 +185,10 @@ def test_job():
179185
job_progress_req = api.JobProgressRequest(job_ids=job_ids)
180186
while tries < tries_limit:
181187
job_progress_res = client.request_response(job_progress_req, url)
188+
assert isinstance(job_progress_res, api.JobProgressResponse)
182189
assert job_progress_res.error is None
183190
is_computed = (
184-
lambda job_id: job_progress_res.progresses[job_id].pct == "100.00"
191+
lambda job_id: job_progress_res.progresses[job_id].pct == "100.00" # ty: ignore[possibly-missing-attribute]
185192
)
186193
if all(is_computed(job_id) for job_id in job_ids):
187194
break
@@ -195,7 +202,7 @@ def test_job():
195202
# gw shutdown
196203
shutdown_req = api.ShutdownRequest()
197204
shutdown_res = client.request_response(shutdown_req, url, 3000)
198-
assert shutdown_res.error is None
205+
assert hasattr(shutdown_res, 'error') and shutdown_res.error is None
199206
gw.join(5)
200207
assert gw.exitcode == 0
201208

tests/cascade/low/test_builders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_func(x: int, y: str) -> int:
1919
"invalid static input for task: x needs int, got <class 'str'>",
2020
"invalid static input for task: y needs str, got <class 'int'>",
2121
]
22-
assert sorted(job.e) == sorted(expected)
22+
assert sorted(job.e) == sorted(expected) # ty: ignore[invalid-argument-type]
2323

2424
task_good = TaskBuilder.from_callable(test_func).with_values(x=1, y="yes")
2525
_ = JobBuilder().with_node("task", task_good).build().get_or_raise()
@@ -35,7 +35,7 @@ def test_func(x: int, y: str) -> int:
3535
expected = [
3636
"edge connects two incompatible nodes: source=source.0 sink_task='sink' sink_input_kw='y' sink_input_ps=None", # noqa: E501
3737
]
38-
assert sorted(job.e) == sorted(expected)
38+
assert sorted(job.e) == sorted(expected) # ty: ignore[invalid-argument-type]
3939

4040
sink_good = TaskBuilder.from_callable(test_func).with_values(y="yes")
4141
_ = (

0 commit comments

Comments
 (0)