Skip to content

Commit 1ddd800

Browse files
authored
Fix automatic step values (#9)
1 parent 7675167 commit 1ddd800

File tree

6 files changed

+447
-4
lines changed

6 files changed

+447
-4
lines changed

src/litlogger/api/metrics_api.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ def _to_v1_metrics_tracker(tracker: MetricsTracker) -> V1MetricsTracker:
103103
return V1MetricsTracker(**kwargs)
104104

105105

106+
def _from_v1_metrics_tracker(v1_tracker: V1MetricsTracker) -> MetricsTracker:
107+
"""Convert V1MetricsTracker from API response to user-facing MetricsTracker."""
108+
return MetricsTracker(
109+
name=v1_tracker.name,
110+
num_rows=v1_tracker.num_rows or 0,
111+
min_value=v1_tracker.min_value,
112+
max_value=v1_tracker.max_value,
113+
min_index=v1_tracker.min_index,
114+
max_index=v1_tracker.max_index,
115+
last_value=v1_tracker.last_value,
116+
last_index=v1_tracker.last_index,
117+
max_user_step=v1_tracker.max_user_step,
118+
)
119+
120+
106121
def _to_v1_phase_type(phase: PhaseType) -> str:
107122
"""Convert user-facing PhaseType to V1PhaseType string.
108123
@@ -332,3 +347,17 @@ def update_experiment_metrics(
332347
trackers=v1_trackers,
333348
),
334349
)
350+
351+
def get_trackers_from_metrics_store(self, metrics_store: Any) -> Dict[str, MetricsTracker] | None:
352+
"""Extract and convert trackers from a metrics store object.
353+
354+
Args:
355+
metrics_store: The metrics store object from the API.
356+
357+
Returns:
358+
Dictionary of MetricsTracker objects, or None if no trackers exist.
359+
"""
360+
if not hasattr(metrics_store, "trackers") or not metrics_store.trackers:
361+
return None
362+
363+
return {name: _from_v1_metrics_tracker(v1_tracker) for name, v1_tracker in metrics_store.trackers.items()}

src/litlogger/background.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
store_created_at: bool,
7070
rate_limiting_interval: int = 1,
7171
max_batch_size: int = 1000,
72+
trackers_init: Dict[str, MetricsTracker] | None = None,
7273
) -> None:
7374
super().__init__(daemon=True)
7475
self.teamspace_id = teamspace_id
@@ -98,7 +99,7 @@ def __init__(
9899
client=metrics_api.client,
99100
)
100101

101-
self.trackers: Dict[str, MetricsTracker] = {}
102+
self.trackers: Dict[str, MetricsTracker] = trackers_init if trackers_init is not None else {}
102103

103104
def run(self) -> None:
104105
self._run()
@@ -210,6 +211,10 @@ def _update_tracker(self, name: str, values: Metrics) -> None:
210211

211212
# Increment the number of rows
212213
for value_obj in values.values:
214+
# Augment with step from tracker if not provided
215+
if value_obj.step is None:
216+
value_obj.step = tracker.num_rows
217+
213218
value = float(value_obj.value)
214219

215220
if tracker.started_at is None and self.store_created_at and value_obj.created_at:

src/litlogger/experiment.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import os
1919
import signal
2020
import sys
21-
from collections.abc import Mapping
2221
from concurrent.futures import ThreadPoolExecutor, as_completed
2322
from datetime import datetime
2423
from multiprocessing import JoinableQueue
@@ -165,6 +164,7 @@ def __init__(
165164
store_created_at=store_created_at,
166165
rate_limiting_interval=rate_limiting_interval,
167166
max_batch_size=max_batch_size,
167+
trackers_init=self._metrics_api.get_trackers_from_metrics_store(self._metrics_store),
168168
)
169169

170170
self._manager.start()
@@ -201,7 +201,7 @@ def teamspace(self) -> "Teamspace":
201201
"""
202202
return self._teamspace
203203

204-
def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None:
204+
def log_metrics(self, metrics: Dict[str, float], step: int | None = None, **kwargs: float) -> None:
205205
"""Log metrics to the experiment with background uploading.
206206
207207
Metrics are buffered locally and uploaded to the cloud in batches to optimize performance.
@@ -211,6 +211,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) ->
211211
metrics: Dictionary mapping metric names to numeric values. Example: {"loss": 0.5, "accuracy": 0.95}.
212212
step: Optional step number for this data point (e.g., training step, epoch).
213213
If None and store_step=True, no step is recorded.
214+
kwargs: Additional metric values. Can be used to provide metrics more natural.
215+
Example: loss=0.5, accuracy: 0.95.
214216
215217
Raises:
216218
RuntimeError: If the background thread encountered an error.
@@ -219,6 +221,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) ->
219221
raise self._manager.exception
220222

221223
batch: Dict[str, Metrics] = {}
224+
225+
metrics.update(kwargs)
222226
for name, value in metrics.items():
223227
created_at = None
224228
if self.store_created_at:

tests/integrations/test_standalone.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,73 @@ def test_console_output():
644644
)
645645

646646

647+
@pytest.mark.cloud()
648+
def test_resume_experiment_with_tracker_initialization():
649+
"""Test that resuming an experiment initializes trackers and augments steps correctly."""
650+
experiment_name = f"standalone_resume_tracker-{uuid.uuid4().hex}"
651+
652+
# First experiment run - log metrics with explicit steps
653+
exp1 = litlogger.init(name=experiment_name, teamspace="oss-litlogger")
654+
655+
for i in range(10):
656+
litlogger.log_metrics({"loss": 1.0 - i * 0.1}, step=i)
657+
658+
litlogger.finalize()
659+
660+
# Store info for verification
661+
project_id = exp1._teamspace.id
662+
stream_id = exp1._metrics_store.id
663+
664+
# Wait for metrics to be available
665+
client = LitRestClient()
666+
for _ in range(30):
667+
response = client.lit_logger_service_get_logger_metrics(project_id=project_id, ids=[stream_id])
668+
if response.named_metrics != {}:
669+
metrics = response.named_metrics
670+
if len(metrics.get("loss", {}).ids_metrics.get(stream_id, {}).metrics_values or []) == 10:
671+
break
672+
sleep(1)
673+
674+
# Second experiment run (resume) - log metrics WITHOUT explicit steps
675+
# The steps should be augmented from tracker's num_rows (which should be 10)
676+
exp2 = litlogger.init(name=experiment_name, teamspace="oss-litlogger")
677+
678+
# Verify that the experiment resumed (same stream ID means same experiment)
679+
assert exp2._metrics_store.id == stream_id, "Expected to resume the same experiment"
680+
681+
# Log 5 more metrics WITHOUT explicit steps - they should get steps 10-14
682+
for i in range(5):
683+
litlogger.log_metrics({"loss": 0.05 - i * 0.01}) # No step parameter
684+
685+
litlogger.finalize()
686+
687+
# Wait for all metrics to be available
688+
for _ in range(30):
689+
response = client.lit_logger_service_get_logger_metrics(project_id=project_id, ids=[stream_id])
690+
if response.named_metrics != {}:
691+
metrics = response.named_metrics
692+
loss_values = metrics.get("loss", {}).ids_metrics.get(stream_id, {}).metrics_values or []
693+
if len(loss_values) == 15: # 10 from first run + 5 from second
694+
break
695+
sleep(1)
696+
697+
# Verify we have all 15 metrics
698+
loss_metrics = response.named_metrics["loss"].ids_metrics[stream_id].metrics_values
699+
assert len(loss_metrics) == 15, f"Expected 15 loss metrics, got {len(loss_metrics)}"
700+
701+
# Verify the steps are sequential (0-9 from first run, 10-14 from second run)
702+
# Steps may come back as strings from the API, so convert to int for comparison
703+
steps = sorted([int(m.step) for m in loss_metrics])
704+
expected_steps = list(range(15))
705+
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
706+
707+
# Cleanup
708+
client.lit_logger_service_delete_metrics_stream(
709+
project_id=project_id,
710+
body=LitLoggerServiceDeleteMetricsStreamBody(ids=[stream_id]),
711+
)
712+
713+
647714
@pytest.mark.cloud()
648715
def test_get_or_create_experiment_metrics():
649716
"""Test get_or_create_experiment_metrics returns existing experiment on second call."""

tests/unittests/api/test_metrics_api.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
from lightning_sdk.lightning_cloud.openapi import (
1010
V1Metrics,
11+
V1MetricsTracker,
1112
V1MetricValue,
1213
V1PhaseType,
1314
)
14-
from litlogger.api.metrics_api import MetricsApi
15+
from litlogger.api.metrics_api import MetricsApi, _from_v1_metrics_tracker
1516
from litlogger.types import MetricsTracker, PhaseType
1617

1718

@@ -428,3 +429,136 @@ def test_update_experiment_metrics_custom_phase(self):
428429
call_args = mock_client.lit_logger_service_update_metrics_stream.call_args
429430
assert call_args.kwargs["body"].persisted is False
430431
assert call_args.kwargs["body"].phase == V1PhaseType.RUNNING
432+
433+
434+
class TestFromV1MetricsTracker:
435+
"""Test the _from_v1_metrics_tracker helper function."""
436+
437+
def test_converts_full_tracker(self):
438+
"""Test converting a V1MetricsTracker with all fields set."""
439+
v1_tracker = V1MetricsTracker(
440+
name="loss",
441+
num_rows=100,
442+
min_value=0.1,
443+
max_value=1.0,
444+
min_index=50,
445+
max_index=0,
446+
last_value=0.2,
447+
last_index=99,
448+
max_user_step=99,
449+
)
450+
451+
result = _from_v1_metrics_tracker(v1_tracker)
452+
453+
assert isinstance(result, MetricsTracker)
454+
assert result.name == "loss"
455+
assert result.num_rows == 100
456+
assert result.min_value == 0.1
457+
assert result.max_value == 1.0
458+
assert result.min_index == 50
459+
assert result.max_index == 0
460+
assert result.last_value == 0.2
461+
assert result.last_index == 99
462+
assert result.max_user_step == 99
463+
464+
def test_converts_minimal_tracker(self):
465+
"""Test converting a V1MetricsTracker with only required fields."""
466+
v1_tracker = MagicMock()
467+
v1_tracker.name = "accuracy"
468+
v1_tracker.num_rows = None
469+
v1_tracker.min_value = None
470+
v1_tracker.max_value = None
471+
v1_tracker.min_index = None
472+
v1_tracker.max_index = None
473+
v1_tracker.last_value = None
474+
v1_tracker.last_index = None
475+
v1_tracker.max_user_step = None
476+
477+
result = _from_v1_metrics_tracker(v1_tracker)
478+
479+
assert isinstance(result, MetricsTracker)
480+
assert result.name == "accuracy"
481+
assert result.num_rows == 0 # Defaults to 0 when None
482+
assert result.min_value is None
483+
assert result.max_value is None
484+
485+
def test_converts_tracker_with_zero_num_rows(self):
486+
"""Test converting a tracker with explicit zero num_rows."""
487+
v1_tracker = V1MetricsTracker(
488+
name="metric",
489+
num_rows=0,
490+
)
491+
492+
result = _from_v1_metrics_tracker(v1_tracker)
493+
494+
assert result.num_rows == 0
495+
496+
497+
class TestGetTrackersFromMetricsStore:
498+
"""Test the get_trackers_from_metrics_store method."""
499+
500+
def test_returns_none_when_no_trackers_attribute(self):
501+
"""Test returns None when metrics store has no trackers attribute."""
502+
mock_client = MagicMock()
503+
api = MetricsApi(client=mock_client)
504+
505+
mock_metrics_store = MagicMock(spec=[]) # No attributes
506+
507+
result = api.get_trackers_from_metrics_store(mock_metrics_store)
508+
509+
assert result is None
510+
511+
def test_returns_none_when_trackers_is_none(self):
512+
"""Test returns None when metrics store trackers is None."""
513+
mock_client = MagicMock()
514+
api = MetricsApi(client=mock_client)
515+
516+
mock_metrics_store = MagicMock()
517+
mock_metrics_store.trackers = None
518+
519+
result = api.get_trackers_from_metrics_store(mock_metrics_store)
520+
521+
assert result is None
522+
523+
def test_returns_none_when_trackers_is_empty(self):
524+
"""Test returns None when metrics store trackers is empty dict."""
525+
mock_client = MagicMock()
526+
api = MetricsApi(client=mock_client)
527+
528+
mock_metrics_store = MagicMock()
529+
mock_metrics_store.trackers = {}
530+
531+
result = api.get_trackers_from_metrics_store(mock_metrics_store)
532+
533+
assert result is None
534+
535+
def test_converts_trackers_from_metrics_store(self):
536+
"""Test successfully converts trackers from metrics store."""
537+
mock_client = MagicMock()
538+
api = MetricsApi(client=mock_client)
539+
540+
mock_metrics_store = MagicMock()
541+
mock_metrics_store.trackers = {
542+
"loss": V1MetricsTracker(
543+
name="loss",
544+
num_rows=100,
545+
min_value=0.1,
546+
max_value=1.0,
547+
),
548+
"accuracy": V1MetricsTracker(
549+
name="accuracy",
550+
num_rows=50,
551+
min_value=0.8,
552+
max_value=0.99,
553+
),
554+
}
555+
556+
result = api.get_trackers_from_metrics_store(mock_metrics_store)
557+
558+
assert result is not None
559+
assert len(result) == 2
560+
assert "loss" in result
561+
assert "accuracy" in result
562+
assert isinstance(result["loss"], MetricsTracker)
563+
assert result["loss"].num_rows == 100
564+
assert result["accuracy"].num_rows == 50

0 commit comments

Comments
 (0)