diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6f04f489d35..5ebeceed5d2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,7 +60,53 @@ jobs: scons -j8 cd opendbc/safety/tests && ./mutation.sh - # TODO: this test needs to move to opendbc + car_diff: + name: car diff + runs-on: ${{ github.repository == 'commaai/opendbc' && 'namespace-profile-amd64-8x16' || 'ubuntu-latest' }} + env: + GIT_REF: ${{ github.event_name == 'push' && github.ref == format('refs/heads/{0}', github.event.repository.default_branch) && github.event.before || format('origin/{0}', github.event.repository.default_branch) }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: ./.github/workflows/cache + - name: Build opendbc + run: | + source setup.sh + scons -j8 + - name: Test car diff + if: github.event_name == 'pull_request' + run: source setup.sh && python opendbc/car/tests/car_diff.py | tee diff.txt + - name: Comment PR + if: always() && github.event_name == 'pull_request' + env: + GH_TOKEN: ${{ github.token }} + run: '[ -s diff.txt ] && gh pr comment ${{ github.event.pull_request.number }} --repo ${{ github.repository }} -F diff.txt || true' + - name: Update refs + if: github.repository == 'commaai/opendbc' && github.ref == 'refs/heads/master' + run: source setup.sh && python opendbc/car/tests/car_diff.py --update-refs + - name: Checkout ci-artifacts + if: github.repository == 'commaai/opendbc' && github.ref == 'refs/heads/master' + uses: actions/checkout@v4 + with: + repository: commaai/ci-artifacts + ssh-key: ${{ secrets.CI_ARTIFACTS_DEPLOY_KEY }} + path: ${{ github.workspace }}/ci-artifacts + - name: Push refs + if: github.repository == 'commaai/opendbc' && github.ref == 'refs/heads/master' + working-directory: ${{ github.workspace }}/ci-artifacts + run: | + ls ${{ github.workspace }}/car_diff/*.zst 2>/dev/null || exit 0 + git config user.name "GitHub Actions Bot" + git config user.email "<>" + git fetch 
origin car_diff || true + git checkout car_diff 2>/dev/null || git checkout --orphan car_diff + cp ${{ github.workspace }}/car_diff/*.zst . + git add *.zst + git commit -m "car_diff refs for ${{ github.sha }}" || echo "No changes to commit" + git push origin car_diff + + # TODO: this needs to move to opendbc test_models: name: test models runs-on: ${{ github.repository == 'commaai/opendbc' && 'namespace-profile-amd64-8x16' || 'ubuntu-latest' }} diff --git a/opendbc/car/logreader.py b/opendbc/car/logreader.py new file mode 100644 index 00000000000..93b2359fb5b --- /dev/null +++ b/opendbc/car/logreader.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +import os +import capnp +import urllib.parse +import warnings +from urllib.request import urlopen +import zstandard as zstd + +from opendbc.car.common.basedir import BASEDIR + +capnp_log = capnp.load(os.path.join(BASEDIR, "rlog.capnp")) + + +def decompress_stream(data: bytes): + dctx = zstd.ZstdDecompressor() + decompressed_data = b"" + + with dctx.stream_reader(data) as reader: + decompressed_data = reader.read() + + return decompressed_data + + +class LogReader: + def __init__(self, fn, only_union_types=False, sort_by_time=False): + self._only_union_types = only_union_types + _, ext = os.path.splitext(urllib.parse.urlparse(fn).path) + + if fn.startswith("http"): + with urlopen(fn) as f: + dat = f.read() + else: + with open(fn, "rb") as f: + dat = f.read() + + if ext == ".zst" or dat.startswith(b'\x28\xB5\x2F\xFD'): + # https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#zstandard-frames + dat = decompress_stream(dat) + + ents = capnp_log.Event.read_multiple_bytes(dat) + + self._ents = [] + try: + for e in ents: + self._ents.append(e) + except capnp.KjException: + warnings.warn("Corrupted events detected", RuntimeWarning, stacklevel=1) + + if sort_by_time: + self._ents.sort(key=lambda x: x.logMonoTime) + + def __iter__(self): + for ent in self._ents: + if self._only_union_types: + try: + ent.which() + 
yield ent + except capnp.lib.capnp.KjException: + pass + else: + yield ent + + def filter(self, msg_type: str): + return (getattr(m, m.which()) for m in filter(lambda m: m.which() == msg_type, self)) + + def first(self, msg_type: str): + return next(self.filter(msg_type), None) diff --git a/opendbc/car/rlog.capnp b/opendbc/car/rlog.capnp new file mode 100644 index 00000000000..50cdd68b5eb --- /dev/null +++ b/opendbc/car/rlog.capnp @@ -0,0 +1,23 @@ +@0xce500edaaae36b0e; + +# Minimal schema for parsing rlog CAN messages +# Subset of cereal/log.capnp + +struct CanData { + address @0 :UInt32; + busTimeDEPRECATED @1 :UInt16; + dat @2 :Data; + src @3 :UInt8; +} + +struct Event { + logMonoTime @0 :UInt64; + + union { + initData @1 :Void; + frame @2 :Void; + gpsNMEA @3 :Void; + sensorEventDEPRECATED @4 :Void; + can @5 :List(CanData); + } +} diff --git a/opendbc/car/tests/car_diff.py b/opendbc/car/tests/car_diff.py new file mode 100644 index 00000000000..a75b8f95dff --- /dev/null +++ b/opendbc/car/tests/car_diff.py @@ -0,0 +1,314 @@ +import argparse +import os +import pickle +import re +import subprocess +import sys +import tempfile +import zstandard as zstd +from urllib.request import urlopen +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor +from pathlib import Path + +from comma_car_segments import get_comma_car_segments_database, get_url + +from opendbc.car.logreader import LogReader, decompress_stream + + +TOLERANCE = 1e-4 +DIFF_BUCKET = "car_diff" +IGNORE_FIELDS = ["cumLagMs", "canErrorCounter"] + + +def dict_diff(d1, d2, path="", ignore=None, tolerance=0): + ignore = ignore or [] + diffs = [] + for key in d1.keys() | d2.keys(): + if key in ignore: + continue + full_path = f"{path}.{key}" if path else key + v1, v2 = d1.get(key), d2.get(key) + if isinstance(v1, dict) and isinstance(v2, dict): + diffs.extend(dict_diff(v1, v2, full_path, ignore, tolerance)) + elif isinstance(v1, (int, float)) and isinstance(v2, (int, float)): + if 
abs(v1 - v2) > tolerance: + diffs.append(("change", full_path, (v1, v2))) + elif v1 != v2: + diffs.append(("change", full_path, (v1, v2))) + return diffs + + +def load_can_messages(seg): + parts = seg.split("/") + url = get_url(f"{parts[0]}/{parts[1]}", parts[2]) + msgs = LogReader(url, only_union_types=True) + return [m for m in msgs if m.which() == 'can'] + + +def replay_segment(platform, can_msgs): + from opendbc.car import gen_empty_fingerprint, structs + from opendbc.car.can_definitions import CanData + from opendbc.car.car_helpers import FRAME_FINGERPRINT, interfaces + + fingerprint = gen_empty_fingerprint() + for msg in can_msgs[:FRAME_FINGERPRINT]: + for m in msg.can: + if m.src < 64: + fingerprint[m.src][m.address] = len(m.dat) + + CarInterface = interfaces[platform] + CP = CarInterface.get_params(platform, fingerprint, [], False, False, False) + CI = CarInterface(CP) + CC = structs.CarControl().as_reader() + + states, timestamps = [], [] + for msg in can_msgs: + frames = [CanData(c.address, c.dat, c.src) for c in msg.can] + states.append(CI.update([(msg.logMonoTime, frames)])) + CI.apply(CC, msg.logMonoTime) + timestamps.append(msg.logMonoTime) + return states, timestamps + + +def process_segment(args): + platform, seg, ref_path, update = args + try: + can_msgs = load_can_messages(seg) + states, timestamps = replay_segment(platform, can_msgs) + ref_file = Path(ref_path) / f"{platform}_{seg.replace('/', '_')}.zst" + + if update: + data = list(zip(timestamps, states, strict=True)) + ref_file.write_bytes(zstd.compress(pickle.dumps(data), 10)) + return (platform, seg, [], None) + + if not ref_file.exists(): + return (platform, seg, [], "no ref") + + ref = pickle.loads(decompress_stream(ref_file.read_bytes())) + diffs = [] + for i, ((ts, ref_state), state) in enumerate(zip(ref, states, strict=True)): + for diff in dict_diff(ref_state.to_dict(), state.to_dict(), ignore=IGNORE_FIELDS, tolerance=TOLERANCE): + diffs.append((diff[1], i, diff[2], ts)) + return 
(platform, seg, diffs, None) + except Exception as e: + return (platform, seg, [], str(e)) + + + def get_changed_platforms(cwd, database, interfaces): + git_ref = os.environ.get("GIT_REF", "origin/master") + changed = subprocess.check_output(["git", "diff", "--name-only", f"{git_ref}...HEAD"], cwd=cwd, encoding='utf8').strip() + brands = set() + patterns = [r"opendbc/car/(\w+)/", r"opendbc/dbc/(\w+?)_", r"opendbc/dbc/generator/(\w+)", r"opendbc/safety/modes/(\w+?)[_.]"] + for line in changed.splitlines(): + for pattern in patterns: + m = re.search(pattern, line) + if m: + brands.add(m.group(1).lower()) + return [p for p in interfaces if any(b in p.lower() for b in brands) and p in database] + + + def download_refs(ref_path, platforms, segments): + base_url = f"https://raw.githubusercontent.com/commaai/ci-artifacts/refs/heads/{DIFF_BUCKET}" + for platform in platforms: + for seg in segments.get(platform, []): + filename = f"{platform}_{seg.replace('/', '_')}.zst" + try: + with urlopen(f"{base_url}/{filename}") as resp: + (Path(ref_path) / filename).write_bytes(resp.read()) + except Exception: + pass + + + def run_replay(platforms, segments, ref_path, update, workers=4): + work = [(platform, seg, ref_path, update) + for platform in platforms for seg in segments.get(platform, [])] + with ProcessPoolExecutor(max_workers=workers) as pool: + return list(pool.map(process_segment, work)) + + + # ASCII waveforms helpers + def find_edges(vals, init): + rises = [] + falls = [] + prev = init + for i, val in enumerate(vals): + if val and not prev: + rises.append(i) + if not val and prev: + falls.append(i) + prev = val + return rises, falls + + + def render_waveform(label, vals, init): + wave = {(False, False): "_", (True, True): "‾", (False, True): "/", (True, False): "\\"} + line = f" {label}:".ljust(12) + prev = init + for val in vals: + line += wave[(prev, val)] + prev = val + if len(line) > 80: + line = line[:80] + "..." 
+ return line + + +def format_timing(edge_type, master_edges, pr_edges, ms_per_frame): + if not master_edges or not pr_edges: + return None + delta = pr_edges[0] - master_edges[0] + if delta == 0: + return None + direction = "lags" if delta > 0 else "leads" + ms = int(abs(delta) * ms_per_frame) + return " " * 12 + f"{edge_type}: PR {direction} by {abs(delta)} frames ({ms}ms)" + + +def group_frames(diffs, max_gap=15): + groups = [] + current = [diffs[0]] + for diff in diffs[1:]: + _, frame, _, _ = diff + _, prev_frame, _, _ = current[-1] + if frame <= prev_frame + max_gap: + current.append(diff) + else: + groups.append(current) + current = [diff] + groups.append(current) + return groups + + +def build_signals(group): + _, first_frame, _, _ = group[0] + _, last_frame, (final_master, _), _ = group[-1] + start = max(0, first_frame - 5) + end = last_frame + 6 + init = not final_master + diff_at = {frame: (m, p) for _, frame, (m, p), _ in group} + master_vals = [] + pr_vals = [] + master = init + pr = init + for frame in range(start, end): + if frame in diff_at: + master, pr = diff_at[frame] + elif frame > last_frame: + master = pr = final_master + master_vals.append(master) + pr_vals.append(pr) + return master_vals, pr_vals, init, start, end + + +def format_numeric_diffs(diffs): + lines = [] + for _, frame, (old_val, new_val), _ in diffs[:10]: + lines.append(f" frame {frame}: {old_val} -> {new_val}") + if len(diffs) > 10: + lines.append(f" (... 
{len(diffs) - 10} more)") + return lines + + +def format_boolean_diffs(diffs): + _, first_frame, _, first_ts = diffs[0] + _, last_frame, _, last_ts = diffs[-1] + frame_time = last_frame - first_frame + time_ms = (last_ts - first_ts) / 1e6 + ms = time_ms / frame_time if frame_time else 10.0 + lines = [] + for group in group_frames(diffs): + master_vals, pr_vals, init, start, end = build_signals(group) + master_rises, master_falls = find_edges(master_vals, init) + pr_rises, pr_falls = find_edges(pr_vals, init) + if bool(master_rises) != bool(pr_rises) or bool(master_falls) != bool(pr_falls): + continue + lines.append(f"\n frames {start}-{end - 1}") + lines.append(render_waveform("master", master_vals, init)) + lines.append(render_waveform("PR", pr_vals, init)) + for edge_type, master_edges, pr_edges in [("rise", master_rises, pr_rises), ("fall", master_falls, pr_falls)]: + msg = format_timing(edge_type, master_edges, pr_edges, ms) + if msg: + lines.append(msg) + return lines + + +def format_diff(diffs): + if not diffs: + return [] + _, _, (old, new), _ = diffs[0] + is_bool = isinstance(old, bool) and isinstance(new, bool) + if is_bool: + return format_boolean_diffs(diffs) + return format_numeric_diffs(diffs) + + +def main(platform=None, segments_per_platform=10, update_refs=False, all_platforms=False): + from opendbc.car.car_helpers import interfaces + + cwd = Path(__file__).resolve().parents[3] + ref_path = cwd / DIFF_BUCKET + if not update_refs: + ref_path = Path(tempfile.mkdtemp()) + ref_path.mkdir(exist_ok=True) + database = get_comma_car_segments_database() + + if all_platforms: + print("Running all platforms...") + platforms = [p for p in interfaces if p in database] + elif platform and platform in interfaces: + platforms = [platform] + else: + platforms = get_changed_platforms(cwd, database, interfaces) + + if not platforms: + print("No car changes detected", file=sys.stderr) + return 0 + + segments = {p: database.get(p, [])[:segments_per_platform] for p in 
platforms} + n_segments = sum(len(s) for s in segments.values()) + print(f"{'Generating' if update_refs else 'Testing'} {n_segments} segments for: {', '.join(platforms)}") + + if update_refs: + results = run_replay(platforms, segments, ref_path, update=True) + errors = [e for _, _, _, e in results if e] + assert len(errors) == 0, f"Segment failures: {errors}" + print(f"Generated {n_segments} refs to {ref_path}") + return 0 + + download_refs(ref_path, platforms, segments) + results = run_replay(platforms, segments, ref_path, update=False) + + with_diffs = [(p, s, d) for p, s, d, e in results if d] + errors = [(p, s, e) for p, s, d, e in results if e] + n_passed = len(results) - len(with_diffs) - len(errors) + + print(f"\nResults: {n_passed} passed, {len(with_diffs)} with diffs, {len(errors)} errors") + + for plat, seg, err in errors: + print(f"\nERROR {plat} - {seg}: {err}") + + if with_diffs: + print("```") + for plat, seg, diffs in with_diffs: + print(f"\n{plat} - {seg}") + by_field = defaultdict(list) + for d in diffs: + by_field[d[0]].append(d) + for field, fd in sorted(by_field.items()): + print(f" {field} ({len(fd)} diffs)") + for line in format_diff(fd): + print(line) + print("```") + + return 1 if errors else 0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--platform", help="diff single platform") + parser.add_argument("--segments-per-platform", type=int, default=10, help="number of segments to diff per platform") + parser.add_argument("--update-refs", action="store_true", help="update refs based on current commit") + parser.add_argument("--all", action="store_true", help="run diff on all platforms") + args = parser.parse_args() + sys.exit(main(args.platform, args.segments_per_platform, args.update_refs, args.all)) diff --git a/pyproject.toml b/pyproject.toml index 1a0ace169d8..d0d3cd27065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ [project.optional-dependencies] testing = 
[ + "comma-car-segments @ https://huggingface.co/datasets/commaai/commaCarSegments/resolve/main/dist/comma_car_segments-0.1.0-py3-none-any.whl", "cffi", "gcovr", # FIXME: pytest 9.0.0 doesn't support unittest.SkipTest @@ -32,6 +33,7 @@ testing = [ "pytest-subtests", "hypothesis==6.47.*", "parameterized>=0.8,<0.9", + "zstandard", # static analysis "ruff",