Skip to content

Commit fa57bad

Browse files
authored
feat: Task Execution Stats + Metric Agg Mode (ENG-1690) (#22)
* Add metric agg mode locally. Add task execution stats. Clean up dependencies. * Remove readme section from pyproject.toml. Typing and linting fixes. * More dependency fixes and aligning with linting/mypy * Refactor CI job * Fixes for CI job * transformers dependency * transformers dependency fix * Fixing typing for transformers integration * More CI workflow adjustments
1 parent c443c14 commit fa57bad

File tree

12 files changed

+1454
-174
lines changed

12 files changed

+1454
-174
lines changed

.github/workflows/semgrep.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
name: 🚨 Semgrep Analysis
2+
name: Semgrep Analysis
33
on:
44
merge_group:
55
pull_request:
@@ -28,7 +28,7 @@ permissions:
2828

2929
jobs:
3030
semgrep:
31-
name: 🚨 Semgrep Analysis
31+
name: Semgrep Analysis
3232
runs-on: ubuntu-latest
3333
container:
3434
image: returntocorp/semgrep

.github/workflows/test.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
python:
11+
name: Python - Lint, Typecheck, Test
12+
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
python-version: ["3.10", "3.11", "3.12"]
17+
18+
runs-on: ubuntu-latest
19+
20+
steps:
21+
- name: Checkout code
22+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
23+
24+
- name: Setup Python ${{ matrix.python-version }}
25+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b
26+
with:
27+
python-version: ${{ matrix.python-version }}
28+
29+
- name: Install Poetry
30+
uses: abatilo/actions-poetry@e78f54a89cb052fff327414dd9ff010b5d2b4dbd
31+
32+
- name: Configure Poetry
33+
run: |
34+
poetry config virtualenvs.create true --local
35+
poetry config virtualenvs.in-project true --local
36+
37+
- name: Cache dependencies
38+
uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
39+
with:
40+
path: ./.venv
41+
key: venv-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('poetry.lock') }}
42+
restore-keys: |
43+
venv-${{ runner.os }}-py${{ matrix.python-version }}-
44+
45+
- name: Install package
46+
run: poetry install --all-extras
47+
48+
- name: Lint
49+
run: poetry run ruff check --output-format=github .
50+
51+
- name: Typecheck
52+
run: poetry run mypy .
53+
54+
- name: Test
55+
run: poetry run pytest

.github/workflows/tests.yaml

Lines changed: 0 additions & 58 deletions
This file was deleted.

dreadnode/artifact/tree_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _build_tree_structure(
262262
}
263263
dir_structure[root_dir_path] = root_node
264264

265-
for file_path in file_nodes_by_path: # noqa: PLC0206
265+
for file_path, file_node in file_nodes_by_path.items():
266266
try:
267267
rel_path = file_path.relative_to(base_dir)
268268
parts = rel_path.parts
@@ -272,7 +272,7 @@ def _build_tree_structure(
272272

273273
# File in the root directory
274274
if len(parts) == 1:
275-
root_node["children"].append(file_nodes_by_path[file_path])
275+
root_node["children"].append(file_node)
276276
continue
277277

278278
# Create parent directories
@@ -295,7 +295,7 @@ def _build_tree_structure(
295295
# Now add the file to its parent directory
296296
parent_dir_str = file_path.parent.resolve().as_posix()
297297
if parent_dir_str in dir_structure:
298-
dir_structure[parent_dir_str]["children"].append(file_nodes_by_path[file_path])
298+
dir_structure[parent_dir_str]["children"].append(file_node)
299299
self._compute_directory_hashes(dir_structure)
300300

301301
return root_node

dreadnode/integrations/transformers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@
55

66
import typing as t
77

8-
from transformers.trainer_callback import ( # type: ignore [import-untyped]
9-
TrainerCallback,
10-
TrainerControl,
11-
TrainerState,
12-
TrainingArguments,
13-
)
8+
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
9+
from transformers.training_args import TrainingArguments
1410

1511
import dreadnode as dn
1612

@@ -28,7 +24,7 @@ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]:
2824
return cleaned
2925

3026

31-
class DreadnodeCallback(TrainerCallback): # type: ignore [misc]
27+
class DreadnodeCallback(TrainerCallback):
3228
"""
3329
An implementation of the `TrainerCallback` interface for Dreadnode.
3430
@@ -124,7 +120,7 @@ def on_epoch_begin(
124120
control: TrainerControl,
125121
**kwargs: t.Any,
126122
) -> None:
127-
if self._run is None:
123+
if self._run is None or state.epoch is None:
128124
return
129125

130126
dn.log_metric("epoch", state.epoch)

dreadnode/main.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
ENV_SERVER,
3333
ENV_SERVER_URL,
3434
)
35-
from dreadnode.metric import Metric, Scorer, ScorerCallable, T
35+
from dreadnode.metric import Metric, MetricMode, Scorer, ScorerCallable, T
3636
from dreadnode.task import P, R, Task
3737
from dreadnode.tracing.exporters import (
3838
FileExportConfig,
@@ -757,6 +757,7 @@ def log_metric(
757757
step: int = 0,
758758
origin: t.Any | None = None,
759759
timestamp: datetime | None = None,
760+
mode: MetricMode = "direct",
760761
to: ToObject = "task-or-run",
761762
) -> None:
762763
"""
@@ -778,6 +779,14 @@ def log_metric(
778779
origin: The origin of the metric - can be provided any object which was logged
779780
as an input or output anywhere in the run.
780781
timestamp: The timestamp of the metric - defaults to the current time.
782+
mode: The aggregation mode to use for the metric. Helpful when you want to let
783+
the library take care of translating your raw values into better representations.
784+
- direct: do not modify the value at all (default)
785+
- min: the lowest observed value reported for this metric
786+
- max: the highest observed value reported for this metric
787+
- avg: the average of all reported values for this metric
788+
- sum: the cumulative sum of all reported values for this metric
789+
- count: increment every time this metric is logged - disregard value
781790
to: The target object to log the metric to. Can be "task-or-run" or "run".
782791
Defaults to "task-or-run". If "task-or-run", the metric will be logged
783792
to the current task or run, whichever is the nearest ancestor.
@@ -790,6 +799,7 @@ def log_metric(
790799
value: Metric,
791800
*,
792801
origin: t.Any | None = None,
802+
mode: MetricMode = "direct",
793803
to: ToObject = "task-or-run",
794804
) -> None:
795805
"""
@@ -809,11 +819,18 @@ def log_metric(
809819
value: The metric object.
810820
origin: The origin of the metric - can be provided any object which was logged
811821
as an input or output anywhere in the run.
822+
mode: The aggregation mode to use for the metric. Helpful when you want to let
823+
the library take care of translating your raw values into better representations.
824+
- direct: do not modify the value at all (default)
825+
- min: always report the lowest observed value for this metric
826+
- max: always report the highest observed value for this metric
827+
- avg: report the average of all values for this metric
828+
- sum: report a rolling sum of all values for this metric
829+
- count: report the number of times this metric has been logged
812830
to: The target object to log the metric to. Can be "task-or-run" or "run".
813831
Defaults to "task-or-run". If "task-or-run", the metric will be logged
814832
to the current task or run, whichever is the nearest ancestor.
815833
"""
816-
... # noqa: PIE790
817834

818835
@handle_internal_errors()
819836
def log_metric(
@@ -824,6 +841,7 @@ def log_metric(
824841
step: int = 0,
825842
origin: t.Any | None = None,
826843
timestamp: datetime | None = None,
844+
mode: MetricMode = "direct",
827845
to: ToObject = "task-or-run",
828846
) -> None:
829847
task = current_task_span.get()
@@ -838,7 +856,7 @@ def log_metric(
838856
if isinstance(value, Metric)
839857
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
840858
)
841-
target.log_metric(key, metric, origin=origin)
859+
target.log_metric(key, metric, origin=origin, mode=mode)
842860

843861
@handle_internal_errors()
844862
def log_artifact(

dreadnode/metric.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
T = t.TypeVar("T")
1212

13+
MetricMode = t.Literal["direct", "avg", "sum", "min", "max", "count"]
14+
1315

1416
@dataclass
1517
class Metric:
@@ -55,6 +57,46 @@ def from_many(
5557
score_attributes = {name: value for name, value, _ in values}
5658
return cls(value=total / weight, step=step, attributes={**attributes, **score_attributes})
5759

def apply_mode(self, mode: "MetricMode", others: "list[Metric]") -> "Metric":
    """
    Apply an aggregation mode to the metric.
    This will modify the metric in place.

    For every mode except "direct", the raw value is preserved in
    `attributes["original"]` and the mode is recorded in `attributes["mode"]`
    before `value` is replaced with the aggregated result.

    Args:
        mode: The mode to apply. One of "direct", "avg", "sum", "min", "max", or "count".
        others: Previously logged metrics for the same key. Their `value` is treated
            as the already-aggregated result (running sum, running average, ...).

    Returns:
        self

    Raises:
        ValueError: If `mode` differs from the mode recorded on the prior metrics.
    """
    # The first prior metric decides the established mode; a prior metric without a
    # recorded mode is treated as "direct". With no priors, any mode is accepted.
    previous_mode = next((m.attributes.get("mode") for m in others), mode) or "direct"
    if mode != previous_mode:
        raise ValueError(
            f"Cannot mix metric modes {mode} != {previous_mode}",
        )

    if mode == "direct":
        return self

    self.attributes["original"] = self.value
    self.attributes["mode"] = mode

    # Oldest first, so prior_values[-1] is the most recent aggregated value.
    prior_values = [m.value for m in sorted(others, key=lambda m: m.timestamp)]

    if mode == "sum":
        # Continue from the latest running total. Using max() here would crash on the
        # first logged value (empty sequence) and would pick a stale total once a
        # negative value has lowered the running sum.
        if prior_values:
            self.value += prior_values[-1]
    elif mode == "min":
        self.value = min([self.value, *prior_values])
    elif mode == "max":
        self.value = max([self.value, *prior_values])
    elif mode == "count":
        self.value = len(others) + 1
    elif mode == "avg" and prior_values:
        # Incremental mean update based on the latest running average.
        current_avg = prior_values[-1]
        self.value = current_avg + (self.value - current_avg) / (len(prior_values) + 1)

    return self
99+
58100

59101
MetricDict = dict[str, list[Metric]]
60102

@@ -83,7 +125,7 @@ class Scorer(t.Generic[T]):
83125
def from_callable(
84126
cls,
85127
tracer: Tracer,
86-
func: ScorerCallable[T] | "Scorer[T]", # noqa: TC010
128+
func: "ScorerCallable[T] | Scorer[T]",
87129
*,
88130
name: str | None = None,
89131
tags: t.Sequence[str] | None = None,

dreadnode/task.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def top_n(
8383
"""
8484
sorted_ = self.sorted(reverse=reverse)[:n]
8585
return (
86-
t.cast(list[R], [span.output for span in sorted_]) # noqa: TC006
86+
t.cast("list[R]", [span.output for span in sorted_])
8787
if as_outputs
8888
else TaskSpanList(sorted_)
8989
)
@@ -246,6 +246,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
246246
run_id=run.run_id,
247247
tracer=self.tracer,
248248
) as span:
249+
span.run.log_metric(f"{self.label}.exec.count", 1, mode="count")
250+
249251
for name, value in params_to_log.items():
250252
span.log_param(name, value)
251253

@@ -254,10 +256,15 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
254256
for name, value in inputs_to_log.items()
255257
]
256258

257-
output = t.cast(R | t.Awaitable[R], self.func(*args, **kwargs)) # noqa: TC006
258-
if inspect.isawaitable(output):
259-
output = await output
259+
try:
260+
output = t.cast("R | t.Awaitable[R]", self.func(*args, **kwargs))
261+
if inspect.isawaitable(output):
262+
output = await output
263+
except Exception:
264+
span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg")
265+
raise
260266

267+
span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg")
261268
span.output = output
262269

263270
if self.log_output:

0 commit comments

Comments
 (0)