11# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
22# SPDX-License-Identifier: Apache-2.0
3- from collections .abc import Sequence
3+ from collections .abc import Collection
44from dataclasses import dataclass
55from datetime import datetime , timezone
66from typing import Literal
@@ -84,7 +84,7 @@ async def create_or_update_bulk(self, inputs: list[InputDTO]) -> None:
8484
8585 async def list_by_operation_ids (
8686 self ,
87- operation_ids : Sequence [UUID ],
87+ operation_ids : Collection [UUID ],
8888 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
8989 ) -> list [InputRow ]:
9090 if not operation_ids :
@@ -98,14 +98,14 @@ async def list_by_operation_ids(
9898 where = [
9999 Input .created_at >= min_created_at ,
100100 Input .created_at <= max_created_at ,
101- Input .operation_id == any_ (operation_ids ), # type: ignore[arg-type]
101+ Input .operation_id == any_ (list ( operation_ids ) ), # type: ignore[arg-type]
102102 ]
103103
104104 return await self ._get_inputs (where , granularity )
105105
106106 async def list_by_run_ids (
107107 self ,
108- run_ids : Sequence [UUID ],
108+ run_ids : Collection [UUID ],
109109 since : datetime ,
110110 until : datetime | None ,
111111 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
@@ -118,7 +118,7 @@ async def list_by_run_ids(
118118
119119 where = [
120120 Input .created_at >= min_created_at ,
121- Input .run_id == any_ (run_ids ), # type: ignore[arg-type]
121+ Input .run_id == any_ (list ( run_ids ) ), # type: ignore[arg-type]
122122 ]
123123 if until :
124124 where .append (Input .created_at <= until )
@@ -127,7 +127,7 @@ async def list_by_run_ids(
127127
128128 async def list_by_job_ids (
129129 self ,
130- job_ids : Sequence [int ],
130+ job_ids : Collection [int ],
131131 since : datetime ,
132132 until : datetime | None ,
133133 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
@@ -137,7 +137,7 @@ async def list_by_job_ids(
137137
138138 where = [
139139 Input .created_at >= since ,
140- Input .job_id == any_ (job_ids ), # type: ignore[arg-type]
140+ Input .job_id == any_ (list ( job_ids ) ), # type: ignore[arg-type]
141141 ]
142142 if until :
143143 where .append (Input .created_at <= until )
@@ -146,7 +146,7 @@ async def list_by_job_ids(
146146
147147 async def list_by_dataset_ids (
148148 self ,
149- dataset_ids : Sequence [int ],
149+ dataset_ids : Collection [int ],
150150 since : datetime ,
151151 until : datetime | None ,
152152 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
@@ -156,7 +156,7 @@ async def list_by_dataset_ids(
156156
157157 where = [
158158 Input .created_at >= since ,
159- Input .dataset_id == any_ (dataset_ids ), # type: ignore[arg-type]
159+ Input .dataset_id == any_ (list ( dataset_ids ) ), # type: ignore[arg-type]
160160 ]
161161 if until :
162162 where .append (Input .created_at <= until )
@@ -165,7 +165,7 @@ async def list_by_dataset_ids(
165165
166166 async def _get_inputs (
167167 self ,
168- where : list [ColumnElement ],
168+ where : Collection [ColumnElement ],
169169 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
170170 ) -> list [InputRow ]:
171171 if granularity == "OPERATION" :
@@ -277,7 +277,7 @@ async def _get_inputs(
277277 )
278278 return results
279279
280- async def get_stats_by_operation_ids (self , operation_ids : Sequence [UUID ]) -> dict [UUID , Row ]:
280+ async def get_stats_by_operation_ids (self , operation_ids : Collection [UUID ]) -> dict [UUID , Row ]:
281281 if not operation_ids :
282282 return {}
283283
@@ -298,15 +298,15 @@ async def get_stats_by_operation_ids(self, operation_ids: Sequence[UUID]) -> dic
298298 .where (
299299 Input .created_at >= min_created_at ,
300300 Input .created_at <= max_created_at ,
301- Input .operation_id == any_ (operation_ids ), # type: ignore[arg-type]
301+ Input .operation_id == any_ (list ( operation_ids ) ), # type: ignore[arg-type]
302302 )
303303 .group_by (Input .operation_id )
304304 )
305305
306306 query_result = await self ._session .execute (query )
307307 return {row .operation_id : row for row in query_result .all ()}
308308
309- async def get_stats_by_run_ids (self , run_ids : Sequence [UUID ]) -> dict [UUID , Row ]:
309+ async def get_stats_by_run_ids (self , run_ids : Collection [UUID ]) -> dict [UUID , Row ]:
310310 if not run_ids :
311311 return {}
312312
@@ -322,7 +322,7 @@ async def get_stats_by_run_ids(self, run_ids: Sequence[UUID]) -> dict[UUID, Row]
322322 )
323323 .where (
324324 Input .created_at >= min_created_at ,
325- Input .run_id == any_ (run_ids ), # type: ignore[arg-type]
325+ Input .run_id == any_ (list ( run_ids ) ), # type: ignore[arg-type]
326326 )
327327 .group_by (Input .run_id )
328328 )
0 commit comments