Skip to content

Commit 3d49f94

Browse files
rmitsch and Raphael Mitsch authored
refactor: Drop distillation task; allow per Task.distill() only. (#173)
* refactor: Drop distillation task; allow per `Task.distill()` only.
* refactor: Drop `train_frac` in distill().
* fix: Remove `train_frac` specs.

---------

Co-authored-by: Raphael Mitsch <[email protected]>
1 parent f21b346 commit 3d49f94

File tree

19 files changed

+188
-358
lines changed

19 files changed

+188
-358
lines changed

AGENTS.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Key packages and concepts: `sieves.data.Doc`, `sieves.pipeline.Pipeline`, `sieve
4242

4343
## Environments & Installation
4444

45-
Supported Python: `>=3.10` (tests and typing target 3.11).
45+
Supported Python: `>=3.12`.
4646

4747
Using `uv` (preferred):
4848
- Base: `uv sync`
@@ -92,7 +92,7 @@ Notes for agents:
9292
- Type checking: mypy strict (`[tool.mypy]` in `pyproject.toml`)
9393
- Linting: ruff (E, F, I, UP), isort via ruff
9494
- Formatting: black (line length 120)
95-
- Python target version: 3.10 for style/format; 3.11 for mypy config
95+
- Python target version: 3.12
9696
- Avoid one‑letter variable names; keep changes minimal and focused
9797

9898
## Development Commands

sieves/engines/engine_import.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
"""
2-
Imports 3rd-party libraries required for engines. If library can't be found, placeholder engines is imported instead.
1+
"""Import 3rd-party libraries required for engines.
2+
3+
If a library can't be found, a placeholder engine is imported instead.
4+
35
This allows us to import everything downstream without having to worry about optional dependencies. If a user specifies
46
an engine/model from a non-installed library, we terminate with an error.
57
"""
@@ -93,8 +95,8 @@
9395
_missing_dependencies.append("vllm")
9496

9597
warnings.warn(
96-
"Warning: engine dependencies [{deps}] could not be imported. The corresponding engines won't work "
97-
"unless this dependency has been installed.".format(deps=", ".join(_missing_dependencies))
98+
"Warning: structured generation dependencies [{deps}] could not be imported. Generating with them requires them to"
99+
" be installed.".format(deps=", ".join(_missing_dependencies))
98100
)
99101

100102

sieves/pipeline/core.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import copy
66
import itertools
7-
import typing
87
from collections.abc import Iterable, Iterator, Sized
98
from pathlib import Path
109
from typing import Any
@@ -13,7 +12,7 @@
1312

1413
from sieves.data import Doc
1514
from sieves.serialization import Attribute, Config, Serializable
16-
from sieves.tasks import Distillation, PredictiveTask, Task
15+
from sieves.tasks import Task
1716

1817

1918
class Pipeline:
@@ -36,7 +35,6 @@ def __init__(
3635
self._cache: dict[int, Doc] = {}
3736
self._cache_stats: dict[str, int] = {"total": 0, "unique": 0, "hits": 0, "misses": 0}
3837
self._validate_tasks()
39-
self._set_distillation_targets()
4038

4139
def add_tasks(self, tasks: Iterable[Task]) -> None:
4240
"""Add tasks to pipeline. Revalidates pipeline.
@@ -45,7 +43,6 @@ def add_tasks(self, tasks: Iterable[Task]) -> None:
4543
"""
4644
self._tasks.extend(tasks)
4745
self._validate_tasks()
48-
self._set_distillation_targets()
4946

5047
@property
5148
def tasks(self) -> list[Task]:
@@ -75,17 +72,6 @@ def _validate_tasks(self) -> None:
7572
raise ValueError(f"Task with duplicate ID {task.id}. Ensure unique task IDs.")
7673
task_ids.add(task.id)
7774

78-
def _set_distillation_targets(self) -> None:
79-
"""Set target task references fpr distillation tasks, if there are any.
80-
81-
This is necessary because distillation tasks have a lazily initialized required attribute.
82-
"""
83-
for task in self._tasks:
84-
if isinstance(task, Distillation):
85-
target_task = self[task.target_task_id]
86-
assert issubclass(type(target_task), PredictiveTask)
87-
task.target_task = typing.cast(PredictiveTask, target_task) # type: ignore[type-arg]
88-
8975
def _get_unseen_unique_docs(self, docs: Iterable[Doc]) -> Iterable[Doc]:
9076
"""Yield unseen, unique docs.
9177
@@ -273,5 +259,5 @@ def __iadd__(self, other: Task | Pipeline) -> Pipeline:
273259
else:
274260
raise TypeError(f"Can only add Task or Pipeline to Pipeline with +=, got {type(other).__name__}")
275261
self._validate_tasks()
276-
self._set_distillation_targets()
262+
277263
return self

sieves/tasks/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from . import predictive, preprocessing
44
from .core import Task
5-
from .postprocessing import Distillation, DistillationFramework
5+
from .postprocessing import DistillationFramework
66
from .predictive import (
77
NER,
88
Classification,
@@ -19,7 +19,6 @@
1919
__all__ = [
2020
"Chunking",
2121
"Classification",
22-
"Distillation",
2322
"DistillationFramework",
2423
"NER",
2524
"InformationExtraction",
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from .distillation import Distillation, DistillationFramework
1+
"""Postprocessing tasks."""
22

3-
__all__ = ["Distillation", "DistillationFramework"]
3+
from .distillation import DistillationFramework
4+
5+
__all__ = ["DistillationFramework"]
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .core import Distillation
1+
"""Distillation."""
2+
23
from .types import DistillationFramework, DistillationFrameworkLiteral
34

4-
__all__ = ["Distillation", "DistillationFramework", "DistillationFrameworkLiteral"]
5+
__all__ = ["DistillationFramework", "DistillationFrameworkLiteral"]

sieves/tasks/postprocessing/distillation/core.py

Lines changed: 0 additions & 132 deletions
This file was deleted.
Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
"""
2-
Imports 3rd-party libraries required for distillation. If library can't be found, placeholder engines is imported
3-
instead.
1+
"""Import 3rd-party libraries required for distillation.
2+
3+
If a library can't be found, a placeholder engine is imported instead.
4+
45
This allows us to import everything downstream without having to worry about optional dependencies. If a user specifies
56
a non-installed distillation framework, we terminate with an error.
67
"""
@@ -9,34 +10,34 @@
910

1011
import warnings
1112

12-
_MISSING_WARNING = (
13-
"Warning: engine dependency `{missing_dependency}` could not be imported. The corresponding engines won't work "
14-
"unless this dependency has been installed."
15-
)
13+
_missing_dependencies: list[str] = []
1614

1715

1816
try:
1917
import sentence_transformers
2018
except ModuleNotFoundError:
2119
sentence_transformers = None
2220

23-
warnings.warn(_MISSING_WARNING.format(missing_dependency="sentence_transformers"))
24-
21+
_missing_dependencies.append("sentence_transformers")
2522

2623
try:
2724
import setfit
2825
except ModuleNotFoundError:
2926
setfit = None
3027

31-
warnings.warn(_MISSING_WARNING.format(missing_dependency="setfit"))
28+
_missing_dependencies.append("setfit")
3229

3330
try:
3431
import model2vec
3532
import model2vec.train
3633
except ModuleNotFoundError:
3734
model2vec = None
3835

39-
warnings.warn(_MISSING_WARNING.format(missing_dependency="model2vec"))
36+
_missing_dependencies.append("model2vec")
4037

38+
warnings.warn(
39+
"Warning: distillation dependency [{deps}] could not be imported. Distilling with these tools requires them to "
40+
"be installed.".format(deps=", ".join(_missing_dependencies))
41+
)
4142

4243
__all__ = ["model2vec", "sentence_transformers", "setfit"]

0 commit comments

Comments
 (0)