Skip to content

Commit 970f3e4

Browse files
authored
Implement subset splitting (#4965)
1 parent 763886a commit 970f3e4

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+1592
-255
lines changed

application/backend/app/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
import click
1111

12-
from app.core.models.task_type import TaskType
1312
from app.db import MigrationManager, get_db_session
1413
from app.db.schema import DatasetItemDB, LabelDB, ModelRevisionDB, PipelineDB, ProjectDB, SinkDB, SourceDB
14+
from app.models import TaskType
1515
from app.schemas import DisconnectedSinkConfig, DisconnectedSourceConfig, OutputFormat, SinkType, SourceType
1616
from app.schemas.model import TrainingStatus
1717
from app.schemas.pipeline import FixedRateDataCollectionPolicy

application/backend/app/core/jobs/exec/process_run.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from multiprocessing.connection import Connection
2424
from multiprocessing.context import SpawnProcess
2525
from multiprocessing.synchronize import Event
26-
from pathlib import Path
2726

2827
from app.core.jobs.models import Done, ExecutionEvent, Failed, Job, JobType, Started
2928
from app.core.run import ExecutionContext, RunnableFactory, Runner
@@ -39,14 +38,12 @@ class ProcessRun:
3938
4039
Args:
4140
ctx (mp.context.SpawnContext): Multiprocessing context for process & IPC.
42-
data_dir (Path): Directory for job data.
4341
runnable_factory (RunnableFactory): Factory to create runnable job instances.
4442
job (Job): Job specification.
4543
"""
4644

47-
def __init__(self, ctx: mp.context.SpawnContext, data_dir: Path, runnable_factory: RunnableFactory, job: Job):
45+
def __init__(self, ctx: mp.context.SpawnContext, runnable_factory: RunnableFactory, job: Job):
4846
self._ctx = ctx
49-
self._data_dir = data_dir
5047
self._runnable_factory = runnable_factory
5148
self._job = job
5249
self._parent, self._child = ctx.Pipe(duplex=False)
@@ -58,7 +55,6 @@ def start(self) -> "ProcessRun":
5855
target=_entrypoint,
5956
args=(
6057
self._runnable_factory,
61-
self._data_dir,
6258
self._job.job_type,
6359
self._job.params.model_dump_json(),
6460
self._child,
@@ -118,7 +114,7 @@ async def stop(self, graceful_timeout: float = 6.0, term_timeout: float = 3.0, k
118114

119115

120116
def _entrypoint(
121-
get_runnable: RunnableFactory, data_dir: Path, job_type: str, payload: str, conn: Connection, cancel_event: Event
117+
get_runnable: RunnableFactory, job_type: str, payload: str, conn: Connection, cancel_event: Event
122118
) -> None:
123119
"""
124120
Entrypoint for the child process.
@@ -127,7 +123,6 @@ def _entrypoint(
127123
128124
Args:
129125
get_runnable (RunnableFactory): Factory to create runnable job instance.
130-
data_dir (Path): Directory for job data.
131126
job_type (str): Type of job to execute.
132127
payload (str): Serialized job parameters.
133128
conn (Connection): IPC connection to parent process.
@@ -150,7 +145,7 @@ def heartbeat():
150145

151146
try:
152147
conn.send(Started())
153-
runnable.run(ExecutionContext(payload=payload, data_dir=data_dir, report=report, heartbeat=heartbeat))
148+
runnable.run(ExecutionContext(payload=payload, report=report, heartbeat=heartbeat))
154149
conn.send(Done())
155150
except CancelledExc:
156151
conn.send(Cancelled())
@@ -166,17 +161,15 @@ class ProcessRunnerFactory:
166161
Factory for creating process-based job runners.
167162
168163
Args:
169-
data_dir (Path): Directory for job data.
170164
runnable_factory (RunnableFactory): Factory to create runnable job instances.
171165
172166
Methods:
173167
for_job(job: Job) -> Runner[Job, ExecutionEvent]: Create a ProcessRun instance for the given job.
174168
"""
175169

176-
def __init__(self, data_dir: Path, runnable_factory: RunnableFactory) -> None:
170+
def __init__(self, runnable_factory: RunnableFactory) -> None:
177171
# consider using native context for python 3.14 due to upgrade to 'fork_server' model
178172
self._ctx = mp.get_context("spawn")
179-
self._data_dir = data_dir
180173
self._runnable_factory = runnable_factory
181174

182175
def for_job(self, job: Job) -> Runner[Job, ExecutionEvent]:
@@ -189,4 +182,4 @@ def for_job(self, job: Job) -> Runner[Job, ExecutionEvent]:
189182
Returns:
190183
Runner[Job, ExecutionEvent]: Process-based job runner.
191184
"""
192-
return ProcessRun(self._ctx, self._data_dir, self._runnable_factory, job)
185+
return ProcessRun(self._ctx, self._runnable_factory, job)

application/backend/app/core/jobs/exec/thread_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class ThreadAwareExecutionContext(ExecutionContext):
101101
def __init__(self, runner: "ThreadRun"):
102102
self.runner = runner
103103

104-
def report_progress(self, message: str = "training", progress: float = 0.0):
104+
def report(self, message: str = "training", progress: float = 0.0):
105105
if not self.runner._cancel_event.is_set():
106106
self.runner._event_queue.put(Progress(message, progress))
107107

application/backend/app/core/models/__init__.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

application/backend/app/core/run/runnable.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from collections.abc import Callable
1818
from dataclasses import dataclass
19-
from pathlib import Path
2019
from typing import Generic, Protocol, TypeVar
2120

2221
ReportFn = Callable[[str, float], None]
@@ -26,14 +25,9 @@
2625
@dataclass(kw_only=True)
2726
class ExecutionContext:
2827
payload: str
29-
data_dir: Path
3028
report: ReportFn
3129
heartbeat: HeartbeatFn
3230

33-
def report_progress(self, msg: str = "", progress: float = 0.0) -> None:
34-
"""Report progress of the execution."""
35-
self.report(msg, progress)
36-
3731

3832
class Runnable(Protocol): # ignore
3933
"""Generic interface for activities executed by runners."""

application/backend/app/lifecycle.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515

1616
from app.core.jobs import JobController, JobQueue, ProcessRunnerFactory
1717
from app.core.run import Runnable, RunnableFactory
18-
from app.db import MigrationManager
18+
from app.db import MigrationManager, get_db_session
1919
from app.scheduler import Scheduler
2020
from app.schemas.job import JobType
2121
from app.services.base_weights_service import BaseWeightsService
2222
from app.services.data_collect import DataCollector
2323
from app.services.event.event_bus import EventBus
2424
from app.services.training import OTXTrainer
25+
from app.services.training.subset_assignment import SubsetAssigner, SubsetService
2526
from app.settings import get_settings
2627
from app.webrtc.manager import WebRTCManager
2728

@@ -30,23 +31,36 @@
3031

3132
def setup_job_controller(data_dir: Path, max_parallel_jobs: int) -> tuple[JobQueue, JobController]:
3233
"""
33-
Set up job controller with queue and processing infrastructure.
34+
Initializes and configures the job queue and job controller for managing parallel job execution.
3435
35-
Creates a job queue and controller with configured parallel job limits and training infrastructure
36-
for job execution.
36+
Sets up the infrastructure to run jobs concurrently and registers classes that comply with the Runnable protocol,
37+
each associated with a job type and its required dependencies. These classes are executed in a context defined
38+
by the runner factory.
3739
3840
Args:
39-
data_dir: Path to the data directory.
41+
data_dir: Path to the directory containing data required for job execution.
4042
max_parallel_jobs (int): Maximum number of jobs that can run concurrently.
4143
4244
Returns:
43-
tuple[JobQueue, JobController]: A tuple containing the job queue instance and the configured job controller.
45+
tuple[JobQueue, JobController]: The job queue and the configured job controller.
4446
"""
4547
q = JobQueue()
4648
job_runnable_factory = RunnableFactory[JobType, Runnable]()
4749
base_weights_service = BaseWeightsService(data_dir=data_dir)
48-
job_runnable_factory.register(JobType.TRAIN, partial(OTXTrainer, base_weights_service=base_weights_service))
49-
process_runner_factory = ProcessRunnerFactory(data_dir, job_runnable_factory)
50+
subset_service = SubsetService()
51+
subset_assigner = SubsetAssigner()
52+
job_runnable_factory.register(
53+
JobType.TRAIN,
54+
partial(
55+
OTXTrainer,
56+
base_weights_service=base_weights_service,
57+
subset_service=subset_service,
58+
subset_assigner=subset_assigner,
59+
data_dir=data_dir,
60+
db_session_factory=get_db_session,
61+
),
62+
)
63+
process_runner_factory = ProcessRunnerFactory(job_runnable_factory)
5064
job_controller = JobController(
5165
jobs_queue=q, runner_factory=process_runner_factory, max_parallel_jobs=max_parallel_jobs
5266
)

application/backend/app/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .dataset_item import DatasetItem, DatasetItemAnnotation, DatasetItemFormat, DatasetItemSubset
55
from .label import Label, LabelReference
66
from .shape import FullImage, Point, Polygon, Rectangle, Shape
7+
from .task_type import TaskType
78

89
__all__ = [
910
"DatasetItem",
@@ -17,4 +18,5 @@
1718
"Polygon",
1819
"Rectangle",
1920
"Shape",
21+
"TaskType",
2022
]

application/backend/app/repositories/dataset_item_repo.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlalchemy.orm import Session
99

1010
from app.db.schema import DatasetItemDB, DatasetItemLabelDB
11+
from app.models import DatasetItemSubset
1112

1213

1314
class UpdateDatasetItemAnnotation(NamedTuple):
@@ -128,19 +129,20 @@ def get_subset(self, obj_id: str) -> str | None:
128129
)
129130
return self.db.scalar(stmt)
130131

131-
def set_subset(self, obj_id: str, subset: str) -> None:
132+
def set_subset(self, obj_ids: set[str], subset: str) -> int:
132133
stmt = (
133134
update(DatasetItemDB)
134135
.where(
135136
DatasetItemDB.project_id == self.project_id,
136-
DatasetItemDB.id == obj_id,
137+
DatasetItemDB.id.in_(obj_ids),
137138
)
138139
.values(
139140
subset=subset,
140141
updated_at=datetime.now(UTC),
141142
)
142143
)
143-
self.db.execute(stmt)
144+
result = self.db.execute(stmt)
145+
return result.rowcount or 0
144146

145147
def set_labels(self, dataset_item_id: str, label_ids: set[str]) -> None:
146148
self.delete_labels(dataset_item_id)
@@ -153,3 +155,23 @@ def set_labels(self, dataset_item_id: str, label_ids: set[str]) -> None:
153155
def delete_labels(self, dataset_item_id: str) -> None:
154156
stmt = delete(DatasetItemLabelDB).where(DatasetItemLabelDB.dataset_item_id == dataset_item_id)
155157
self.db.execute(stmt)
158+
159+
def list_unassigned_items(self) -> list[DatasetItemLabelDB]:
160+
stmt = (
161+
select(DatasetItemLabelDB)
162+
.join(DatasetItemDB)
163+
.where(
164+
DatasetItemDB.project_id == self.project_id,
165+
DatasetItemDB.subset == DatasetItemSubset.UNASSIGNED,
166+
)
167+
)
168+
return list(self.db.scalars(stmt).all())
169+
170+
def get_subset_distribution(self) -> dict[str, int]:
171+
stmt = (
172+
select(DatasetItemDB.subset, func.count(DatasetItemDB.id).label("count"))
173+
.where(DatasetItemDB.project_id == self.project_id)
174+
.group_by(DatasetItemDB.subset)
175+
)
176+
result = self.db.execute(stmt)
177+
return {row.subset: row.count for row in result} # type: ignore[misc]

0 commit comments

Comments (0)