
Commit 1172ed6

DocGarbanzo and claude committed
Update documentation and configuration for segment-based performance system
Updates CLAUDE.md with refined segment implementation status, enhances configuration templates with segment feature documentation, and improves test coverage for segment training pipeline integration.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent a4f82b8 commit 1172ed6

9 files changed: +264 -78 lines changed


CLAUDE.md

Lines changed: 51 additions & 12 deletions
@@ -473,9 +473,10 @@ Visualizes recorded vehicle trajectories, detects laps, computes mean reference
473473
courses, and segments courses into geometric features.
474474

475475
**Segment stats:** For Tub data, imupath computes segment performance on the
476-
fly using `FIELD_AGGREGATIONS` and `LAP_SORTING_CRITERIA`. It loads
477-
`./config.py` by default; pass `--config` to use another config and include
478-
custom tub fields in the Segment Stats selector.
476+
fly using `FIELD_AGGREGATIONS` (the single source of truth for both field
477+
aggregation and ranking). It loads `./config.py` by default; pass `--config`
478+
to use another config and include custom tub fields in the Segment Stats
479+
selector.
479480
Web UI segment stats use TubStatistics session rankings from manifest
480481
metadata, so `donkey segment` must have stored segmentation data for the
481482
session.
@@ -602,12 +603,18 @@ lap" that outperforms any single recorded lap.
602603
**Data Structure:**
603604
- Segment assignments stored in manifest metadata (NOT in catalog records)
604605
- Metadata stores: segmentation parameters, segment boundaries, rankings
605-
- Performance rankings: `session_rank[session_id][lap_num][segment_id] =
606-
[time_pct, gyro_z_pct, distance_pct]`
606+
- Performance rankings computed from FIELD_AGGREGATIONS:
607+
`session_rank[session_id][lap_num][segment_id] = {field1_pct, field2_pct, ...}`
608+
- The `lap_pct` vector passed to training matches FIELD_AGGREGATIONS order:
609+
`[time_pct, distance_pct, gyro_z_pct, ...]`
607610

608611
**IMPORTANT:** See "Tub Data Integrity" section - segment data is computed at
609612
training time from manifest metadata, NOT stored in individual records.
610613

614+
**Single source of truth:** `FIELD_AGGREGATIONS` defines both what gets
615+
aggregated AND how laps/segments are ranked. The order of entries determines
616+
ranking priority (first entry is primary sort key).
617+
611618
**Example:** 3 laps, 4 segments per lap
612619

613620
Lap 1: Segments [Fast, Slow, Medium, Fast]
@@ -625,17 +632,49 @@ segment!
625632

626633
### Configuration
627634

628-
`donkeycar/templates/cfg_complete.py`:
635+
**Primary config:** `donkeycar/templates/cfg_donkey5.py` (cfg_complete.py uses
636+
deprecated LAP_SORTING_CRITERIA for backward compatibility)
629637

630638
```python
631-
#SEGMENT PERFORMANCE
632-
SEGMENT_PCT_MODE = False # True = segment-based, False = lap-based
633-
SEGMENT_STRATEGY = 'hybrid' # threshold, extrema, gradient, or hybrid
634-
SEGMENT_LAP_DETECTOR = 'ycrossing' # ycrossing or drift
635-
SEGMENT_MIN_LENGTH = 1.0 # Minimum segment length in meters
636-
SEGMENT_CURVATURE_THRESHOLD = 0.1 # Curvature threshold for segmentation
639+
# Enable segment-based training
640+
SEGMENT_PCT_MODE = True # True = segment-based, False = lap-based
641+
642+
# FIELD_AGGREGATIONS: Single source of truth for:
643+
# 1. Which fields to aggregate per lap/segment
644+
# 2. How to rank laps/segments (order matters!)
645+
# 3. What goes into the lap_pct vector for training
646+
647+
def abs_transform(value):
648+
return abs(value)
649+
650+
FIELD_AGGREGATIONS = [
651+
# Primary ranking: lap/segment time
652+
{'output_key': 'time'}, # Boundary field (no 'field' key)
653+
# Secondary ranking: distance
654+
{'output_key': 'distance'}, # Boundary field
655+
# Tertiary ranking: smoothness via gyro Z-axis
656+
{
657+
'field': 'car/gyro', # Record field
658+
'index': 2,
659+
'output_key': 'gyro_z_agg',
660+
'transform': abs_transform,
661+
'aggregation': 'avg'
662+
}
663+
]
664+
665+
# To train using ONLY time and distance (no behavioral metrics):
666+
# FIELD_AGGREGATIONS = [
667+
# {'output_key': 'time'},
668+
# {'output_key': 'distance'}
669+
# ]
637670
```
638671

672+
**Field types:**
673+
- **Boundary fields**: Computed from lap/segment timing (time, distance).
674+
No 'field' key.
675+
- **Record fields**: Extracted from tub records (gyro, accel, speed). Have
676+
'field' key.
677+
639678
### Iterative Training Strategy
640679

641680
1. Train with segment_pct on initial multi-lap data
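
The CLAUDE.md changes above state that the order of `FIELD_AGGREGATIONS` entries sets the ranking priority. The following is a minimal sketch of that idea only; the `laps` data and the `rank_laps` helper are hypothetical and are not the project's `SortingStrategy`, whose implementation is not shown in this commit.

```python
from typing import Dict, List

# Hypothetical aggregated results per lap, keyed by output_key.
laps: Dict[int, Dict[str, float]] = {
    1: {'time': 12.4, 'distance': 30.1, 'gyro_z_agg': 0.8},
    2: {'time': 11.9, 'distance': 30.5, 'gyro_z_agg': 1.1},
    3: {'time': 11.9, 'distance': 29.8, 'gyro_z_agg': 0.9},
}

# Same shape as the config in the diff: the first entry is the primary key.
FIELD_AGGREGATIONS = [
    {'output_key': 'time'},       # boundary field
    {'output_key': 'distance'},   # boundary field
    {'field': 'car/gyro', 'index': 2, 'output_key': 'gyro_z_agg',
     'aggregation': 'avg'},       # record field
]


def rank_laps(laps: Dict[int, Dict[str, float]],
              field_aggregations: List[dict]) -> List[int]:
    """Sort lap numbers lexicographically by the configured output keys
    (illustrative helper, assumes smaller values rank first)."""
    keys = [spec['output_key'] for spec in field_aggregations]
    return sorted(laps, key=lambda lap: tuple(laps[lap][k] for k in keys))


ranked = rank_laps(laps, FIELD_AGGREGATIONS)
print(ranked)
```

Laps 2 and 3 tie on `time`, so `distance` breaks the tie. Reordering the list entries changes the result without touching any other setting, which is the behavior the "single source of truth" note describes.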

donkeycar/parts/tub_statistics.py

Lines changed: 90 additions & 38 deletions
@@ -125,15 +125,32 @@ def accumulate_fields(self, record):
125125

126126
@dataclass
127127
class FieldAggregationSpec:
128-
"""Specification for aggregating a field across lap/segment."""
129-
field: str # e.g., 'car/gyro'
130-
output_key: str # e.g., 'gyro_z_agg'
128+
"""
129+
Specification for aggregating a field across lap/segment.
130+
131+
Two types of fields:
132+
1. Boundary fields: computed from lap/segment boundaries (time, distance)
133+
- No 'field' attribute (field=None)
134+
- Computed in _finalize_segment_instance()
135+
2. Record fields: extracted from individual records (gyro, accel, etc.)
136+
- Has 'field' attribute (e.g., 'car/gyro')
137+
- Aggregated across records in _aggregate_single_field()
138+
"""
139+
output_key: str # e.g., 'gyro_z_agg' or 'time'
140+
field: Optional[str] = None # e.g., 'car/gyro' (None for boundary)
131141
index: Optional[int] = None # Vector index (None for scalars)
132142
transform: Optional[Callable] = None # Transform function
133143
aggregation: str = 'avg' # avg, sum, min, max, median, delta
144+
reverse: bool = False # For sorting: True = descending
145+
146+
def is_boundary_field(self) -> bool:
147+
"""Check if this is a boundary field (time/distance)."""
148+
return self.field is None
134149

135150
def extract(self, record: dict) -> Optional[float]:
136151
"""Extract and transform value from record."""
152+
if self.is_boundary_field():
153+
return None # Boundary fields are not extracted from records
137154
try:
138155
value = record[self.field]
139156
if self.index is not None:
@@ -193,9 +210,10 @@ def __init__(self,
193210
Construct tub statistics calculator for tub
194211
195212
:param tub: input tub
196-
:param config: Config object (loads FIELD_AGGREGATIONS,
197-
LAP_SORTING_CRITERIA). Required if
198-
field_aggregations not provided.
213+
:param config: Config object (loads FIELD_AGGREGATIONS).
214+
FIELD_AGGREGATIONS is the single source of
215+
truth for both aggregation and ranking.
216+
Required if field_aggregations not provided.
199217
:param sorting_strategy: Optional custom sorting strategy
200218
:param field_aggregations: Optional list of FieldAggregationSpec or
201219
dicts. Required if config not provided.
@@ -233,7 +251,8 @@ def _normalize_field_aggregations(self, field_aggregations: List) -> List[
233251
FieldAggregationSpec]:
234252
"""Convert dict specs to FieldAggregationSpec.
235253
236-
:raises ValueError: If old-style extractor syntax is used.
254+
:raises ValueError: If old-style extractor syntax is used or required
255+
fields are missing.
237256
"""
238257
normalized = []
239258
for spec in field_aggregations:
@@ -245,12 +264,17 @@ def _normalize_field_aggregations(self, field_aggregations: List) -> List[
245264
f'Old-style field_aggregations with "extractor" '
246265
f'not supported for field {spec.get("field", "?")}. '
247266
f'Use "index" parameter instead.')
267+
if 'output_key' not in spec:
268+
raise ValueError(
269+
f'Field aggregation missing required "output_key": '
270+
f'{spec}')
248271
normalized.append(FieldAggregationSpec(
249-
field=spec['field'],
250272
output_key=spec['output_key'],
273+
field=spec.get('field'), # None for boundary fields
251274
index=spec.get('index'),
252275
transform=spec.get('transform'),
253-
aggregation=spec.get('aggregation', 'avg')
276+
aggregation=spec.get('aggregation', 'avg'),
277+
reverse=spec.get('reverse', False)
254278
))
255279
return normalized
256280

@@ -266,37 +290,60 @@ def _load_field_aggregations_from_config(self, config) -> List[
266290
raise ValueError(
267291
'FIELD_AGGREGATIONS not found in config. '
268292
'Please define FIELD_AGGREGATIONS in your config file. '
269-
'Example: FIELD_AGGREGATIONS = [{"field": "car/gyro", '
270-
'"output_key": "gyro_z_agg", "index": 1, "aggregation": "avg"}]')
271-
272-
# Convert config dicts to FieldAggregationSpec
273-
specs = []
274-
for spec_dict in config_specs:
275-
spec = FieldAggregationSpec(
276-
field=spec_dict['field'],
277-
output_key=spec_dict['output_key'],
278-
index=spec_dict.get('index'),
279-
transform=spec_dict.get('transform'),
280-
aggregation=spec_dict.get('aggregation', 'avg')
281-
)
282-
specs.append(spec)
283-
logger.info(f'Loaded field aggregation: {spec.output_key} from '
284-
f'{spec.field}[{spec.index}] using {spec.aggregation}')
293+
'Example: FIELD_AGGREGATIONS = [\n'
294+
' {"output_key": "time"}, # Boundary field\n'
295+
' {"output_key": "distance"},\n'
296+
' {"field": "car/gyro", "output_key": "gyro_z_agg", '
297+
'"index": 2, "aggregation": "avg"}\n'
298+
']')
299+
300+
# Use normalization method for consistency
301+
specs = self._normalize_field_aggregations(config_specs)
302+
303+
for spec in specs:
304+
if spec.is_boundary_field():
305+
logger.info(f'Loaded boundary field: {spec.output_key}')
306+
else:
307+
logger.info(
308+
f'Loaded field aggregation: {spec.output_key} from '
309+
f'{spec.field}[{spec.index}] using {spec.aggregation}')
285310

286311
return specs
287312

288313
def _load_sorting_strategy_from_config(self, config) -> SortingStrategy:
289-
"""Load sorting strategy from config."""
290-
criteria = getattr(config, 'LAP_SORTING_CRITERIA', None)
291-
if criteria:
292-
logger.info(f'Loaded sorting criteria from config: '
293-
f'{[c["key"] for c in criteria]}')
314+
"""
315+
Load sorting strategy from config.
316+
317+
Strategy is built from FIELD_AGGREGATIONS (single source of truth).
318+
For backward compatibility, falls back to LAP_SORTING_CRITERIA if found.
319+
"""
320+
# Check for deprecated LAP_SORTING_CRITERIA
321+
old_criteria = getattr(config, 'LAP_SORTING_CRITERIA', None)
322+
if old_criteria:
323+
logger.warning(
324+
'LAP_SORTING_CRITERIA is DEPRECATED. '
325+
'Use FIELD_AGGREGATIONS instead as the single source of truth. '
326+
'Add time/distance as boundary fields: '
327+
'{"output_key": "time"}, {"output_key": "distance"}')
328+
return SortingStrategy(old_criteria)
329+
330+
# Build strategy from FIELD_AGGREGATIONS
331+
if self.field_aggregations:
332+
criteria = []
333+
for spec in self.field_aggregations:
334+
criteria.append({
335+
'key': spec.output_key,
336+
'transform': spec.transform or (lambda x: x),
337+
'reverse': spec.reverse
338+
})
339+
logger.info(
340+
f'Built sorting strategy from FIELD_AGGREGATIONS: '
341+
f'{[c["key"] for c in criteria]}')
294342
return SortingStrategy(criteria)
295-
else:
296-
logger.info('No LAP_SORTING_CRITERIA in config, using minimal '
297-
'defaults (time, distance). Configure LAP_SORTING_CRITERIA '
298-
'in config to include custom fields like gyro_z_agg.')
299-
return default_lap_sorting_strategy()
343+
344+
# Should never reach here due to validation in __init__
345+
logger.error('No field aggregations available for sorting strategy')
346+
return default_lap_sorting_strategy()
300347

301348
def generate_laptimes_from_records(self, overwrite=False):
302349

@@ -803,11 +850,16 @@ def _calculate_aggregated_fields(self):
803850
804851
Generic implementation that handles any field with custom
805852
extractor and transform functions.
853+
854+
Boundary fields (time, distance) are skipped here - they're
855+
computed in _finalize_segment_instance().
806856
"""
807-
logger.info(f'Calculating {len(self.field_aggregations)} field '
808-
f'aggregations in tub {self.tub.base_path}')
857+
record_fields = [spec for spec in self.field_aggregations
858+
if not spec.is_boundary_field()]
859+
logger.info(f'Calculating {len(record_fields)} field aggregations '
860+
f'from records in tub {self.tub.base_path}')
809861

810-
for field_spec in self.field_aggregations:
862+
for field_spec in record_fields:
811863
self._aggregate_single_field(field_spec)
812864

813865
def _aggregate_single_field(self, spec: FieldAggregationSpec):
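
To make the boundary/record distinction in this file concrete, here is a simplified, standalone sketch. The dataclass mirrors the fields shown in the diff, but `normalize` and the body of `extract` are reconstructions under assumptions (the diff truncates `extract`), and the sample `records` are invented for illustration.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Callable, List, Optional


@dataclass
class FieldAggregationSpec:
    """Simplified stand-in for the class in the diff, not the real one."""
    output_key: str
    field: Optional[str] = None           # None marks a boundary field
    index: Optional[int] = None
    transform: Optional[Callable] = None
    aggregation: str = 'avg'
    reverse: bool = False

    def is_boundary_field(self) -> bool:
        return self.field is None

    def extract(self, record: dict) -> Optional[float]:
        # Boundary fields (time, distance) are never read from records.
        if self.is_boundary_field():
            return None
        try:
            value = record[self.field]
            if self.index is not None:
                value = value[self.index]
        except (KeyError, IndexError, TypeError):
            return None
        return self.transform(value) if self.transform else value


def normalize(specs: List[dict]) -> List[FieldAggregationSpec]:
    """Convert config dicts to specs; only 'output_key' is required."""
    normalized = []
    for spec in specs:
        if 'output_key' not in spec:
            raise ValueError(
                f'Field aggregation missing required "output_key": {spec}')
        normalized.append(FieldAggregationSpec(
            output_key=spec['output_key'],
            field=spec.get('field'),
            index=spec.get('index'),
            transform=spec.get('transform'),
            aggregation=spec.get('aggregation', 'avg'),
            reverse=spec.get('reverse', False)))
    return normalized


specs = normalize([
    {'output_key': 'time'},
    {'output_key': 'distance'},
    {'field': 'car/gyro', 'index': 2, 'output_key': 'gyro_z_agg',
     'transform': abs},
])
# Invented records: only the gyro Z component (index 2) is used.
records = [{'car/gyro': [0.0, 0.1, -0.4]}, {'car/gyro': [0.0, 0.0, 0.2]}]
gyro = specs[2]
values = [v for v in (gyro.extract(r) for r in records) if v is not None]
gyro_avg = mean(values)  # 'avg' aggregation over the transformed values
```

Making `field` optional is what lets a bare `{'output_key': 'time'}` pass validation, while filtering on `is_boundary_field()` keeps such entries out of the per-record aggregation loop.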

donkeycar/pipeline/training.py

Lines changed: 11 additions & 0 deletions
@@ -144,9 +144,20 @@ def train(cfg: Config, tub_paths: str, model: str = None,
144144
elif add_lap_pct or getattr(cfg, 'LAP_QUANTIFIER', None) is not None:
145145
pct_mode = PctMode.LAP
146146

147+
# Extract ranking keys from FIELD_AGGREGATIONS (single source of truth)
148+
ranking_keys = None
149+
if hasattr(cfg, 'FIELD_AGGREGATIONS') and cfg.FIELD_AGGREGATIONS:
150+
ranking_keys = [spec['output_key'] for spec in cfg.FIELD_AGGREGATIONS]
151+
logger.info(f'Extracted ranking keys from FIELD_AGGREGATIONS: '
152+
f'{ranking_keys}')
153+
else:
154+
logger.warning('No FIELD_AGGREGATIONS in config - lap_pct will not be '
155+
'populated. Define FIELD_AGGREGATIONS in your config.')
156+
147157
dataset = TubDataset(config=cfg, tub_paths=all_tub_paths,
148158
seq_size=kl.seq_size(),
149159
add_lap_pct=add_lap_pct,
160+
ranking_keys=ranking_keys,
150161
pct_mode=pct_mode)
151162
train_records, val_records \
152163
= train_test_split(dataset.get_records(), shuffle=True,
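
The hunk above reads `spec['output_key']` directly, which assumes every `FIELD_AGGREGATIONS` entry is a dict. As a hedged sketch, the helper below (hypothetical name `extract_ranking_keys`, not part of the commit) does the same extraction but also tolerates entries that are spec objects with an `output_key` attribute. `SimpleNamespace` stands in for the real `Config` object.

```python
from types import SimpleNamespace
from typing import List, Optional


def extract_ranking_keys(cfg) -> Optional[List[str]]:
    """Return output keys in config order, or None if nothing is configured.

    Illustrative helper: accepts dict entries and attribute-style entries.
    """
    specs = getattr(cfg, 'FIELD_AGGREGATIONS', None)
    if not specs:
        return None
    keys = []
    for spec in specs:
        if isinstance(spec, dict):
            keys.append(spec['output_key'])
        else:
            keys.append(spec.output_key)
    return keys


# Stand-in config objects for demonstration only.
cfg = SimpleNamespace(FIELD_AGGREGATIONS=[
    {'output_key': 'time'},
    {'output_key': 'distance'},
    SimpleNamespace(output_key='gyro_z_agg'),
])
empty_cfg = SimpleNamespace()

keys = extract_ranking_keys(cfg)
no_keys = extract_ranking_keys(empty_cfg)
```

Returning `None` for a missing or empty list matches the hunk's behavior of passing `ranking_keys=None` to `TubDataset` after logging a warning.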

donkeycar/pipeline/types.py

Lines changed: 38 additions & 7 deletions
@@ -188,10 +188,10 @@ def extend(self, session_lap_rank, ranking_keys=None,
188188
For LAP mode: {session_id: {lap: rankings}}
189189
For SEGMENT mode: {session_id: {lap: {segment:
190190
rankings}}}
191-
:param ranking_keys: Optional list of keys to extract for lap_pct.
192-
If None, uses default ('time', 'distance',
193-
'gyro_z_agg')
194-
for backward compatibility.
191+
:param ranking_keys: List of keys to extract for lap_pct. Should match
192+
output_key values from FIELD_AGGREGATIONS in config.
193+
If None, auto-extracts keys from first available
194+
ranking dict (backward compatibility).
195195
:param pct_mode: Performance mode (NONE, LAP, or SEGMENT)
196196
:param segment_id: Pre-computed segment ID (for SEGMENT mode).
197197
If provided, uses instead of reading record.
@@ -202,9 +202,12 @@ def extend(self, session_lap_rank, ranking_keys=None,
202202
session_id = self.underlying['_session_id']
203203
lap_i = self.underlying.get('car/lap', 0)
204204

205-
# Use default keys for backward compatibility
206-
if ranking_keys is None:
207-
ranking_keys = ('time', 'distance', 'gyro_z_agg')
205+
# Auto-extract ranking_keys if not provided (backward compatibility)
206+
if ranking_keys is None and session_lap_rank:
207+
ranking_keys = self._extract_ranking_keys(session_lap_rank,
208+
pct_mode, segment_id)
209+
if ranking_keys is None:
210+
return False # Couldn't determine keys
208211

209212
if pct_mode == PctMode.SEGMENT:
210213
# Use passed segment_id if provided, otherwise read from record
@@ -237,6 +240,34 @@ def extend(self, session_lap_rank, ranking_keys=None,
237240

238241
return False # Couldn't populate lap_pct, exclude from training
239242

243+
def _extract_ranking_keys(self, session_lap_rank, pct_mode, segment_id):
244+
"""
245+
Extract ranking keys from session_lap_rank dict (backward compat).
246+
247+
:return: List of ranking keys or None if can't be determined
248+
"""
249+
session_id = self.underlying['_session_id']
250+
lap_i = self.underlying.get('car/lap', 0)
251+
252+
try:
253+
if pct_mode == PctMode.SEGMENT:
254+
# Try to get segment rankings
255+
if segment_id is None:
256+
segment_id = self.underlying.get('car/segment')
257+
if segment_id is None:
258+
return None
259+
lap_dict = session_lap_rank.get(session_id, {}).get(lap_i)
260+
if lap_dict and segment_id in lap_dict:
261+
return list(lap_dict[segment_id].keys())
262+
else:
263+
# LAP mode or fallback
264+
lap_i_dict = session_lap_rank.get(session_id, {}).get(lap_i)
265+
if lap_i_dict and isinstance(lap_i_dict, dict):
266+
return list(lap_i_dict.keys())
267+
except (KeyError, TypeError, AttributeError):
268+
pass
269+
return None
270+
240271
def __repr__(self) -> str:
241272
return repr(self.underlying)
242273
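
The fallback in `_extract_ranking_keys` infers the key list from whatever ranking dict is stored for the current record. A standalone sketch of that lookup follows; the function name `infer_ranking_keys` and the sample dictionaries are assumptions made for illustration, and the logic is simplified relative to the method in the diff.

```python
from typing import List, Optional


def infer_ranking_keys(session_lap_rank: dict, session_id, lap,
                       segment_id=None) -> Optional[List[str]]:
    """Infer ranking keys from nested ranking metadata.

    LAP layout:     {session_id: {lap: {key: pct}}}
    SEGMENT layout: {session_id: {lap: {segment_id: {key: pct}}}}
    Returns None when the expected entry is missing.
    """
    lap_entry = session_lap_rank.get(session_id, {}).get(lap)
    if not isinstance(lap_entry, dict):
        return None
    if segment_id is None:
        return list(lap_entry.keys())
    segment_entry = lap_entry.get(segment_id)
    if not isinstance(segment_entry, dict):
        return None
    return list(segment_entry.keys())


# Invented sample metadata in both layouts.
lap_ranks = {'s1': {0: {'time': 0.5, 'distance': 0.25}}}
segment_ranks = {'s1': {0: {3: {'time': 0.1, 'gyro_z_agg': 0.9}}}}

lap_keys = infer_ranking_keys(lap_ranks, 's1', 0)
segment_keys = infer_ranking_keys(segment_ranks, 's1', 0, segment_id=3)
missing = infer_ranking_keys(lap_ranks, 's2', 0)
```

With this kind of fallback the key order comes from the stored dict's insertion order, not from `FIELD_AGGREGATIONS`, so passing `ranking_keys` explicitly (as the training.py change does) is the more predictable path.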
