|
1 | 1 | #!/usr/bin/env python3 |
2 | | - |
3 | 2 | # MIT License |
4 | 3 | # |
5 | 4 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. |
|
22 | 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
23 | 22 | # THE SOFTWARE. |
24 | 23 |
|
25 | | - |
26 | | -import os |
27 | 24 | import sys |
28 | | -import csv |
29 | | -import subprocess |
30 | 25 | import pytest |
31 | 26 |
|
32 | 27 |
|
33 | | -def run_rocpd_convert(db_path, out_dir): |
34 | | - """Convert rocpd database to CSV format.""" |
35 | | - os.makedirs(out_dir, exist_ok=True) |
36 | | - cmd = [sys.executable, "-m", "rocpd", "convert", "-i", db_path, "--output-format", "csv", "-d", out_dir] |
37 | | - res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) |
38 | | - assert res.returncode == 0, f"rocpd convert failed\ncmd={' '.join(cmd)}\nstdout={res.stdout}\nstderr={res.stderr}" |
39 | | - |
40 | | - |
41 | | -def find_kernel_trace_csv(out_dir): |
42 | | - """Locate kernel_trace CSV file in output directory.""" |
43 | | - for fn in os.listdir(out_dir): |
44 | | - if fn.endswith("kernel_trace.csv"): |
45 | | - return os.path.join(out_dir, fn) |
46 | | - assert False, f"kernel trace CSV not found in {out_dir}" |
47 | | - |
48 | | - |
49 | | -def load_csv_rows(path): |
50 | | - """Load CSV file and return rows as list of dicts.""" |
51 | | - assert os.path.isfile(path), f"missing CSV: {path}" |
52 | | - with open(path, newline="") as f: |
53 | | - reader = csv.DictReader(f) |
54 | | - rows = list(reader) |
55 | | - assert len(rows) > 0, f"empty CSV: {path}" |
56 | | - return rows |
57 | | - |
58 | | - |
def extract_json_kernel_records(json_root):
    """Return the kernel dispatch records found in the parsed JSON output.

    Args:
        json_root: Parsed JSON document. It must contain a
            ``"rocprofiler-sdk-tool"`` entry; when that entry is a non-empty
            list, its first element is used.

    Returns:
        The first non-empty list stored in ``buffer_records`` under one of
        ``kernel_dispatch``, ``kernel_trace`` or ``kernel_dispatch_trace``
        (checked in that order).

    Raises:
        AssertionError: If a required section is missing, or if none of the
            candidate keys holds a non-empty list.
    """
    assert "rocprofiler-sdk-tool" in json_root, "missing rocprofiler-sdk-tool in JSON"
    tool = json_root["rocprofiler-sdk-tool"]
    if isinstance(tool, list) and len(tool) > 0:
        tool = tool[0]

    assert "buffer_records" in tool, "missing buffer_records in JSON"
    br = tool["buffer_records"]

    for key in ("kernel_dispatch", "kernel_trace", "kernel_dispatch_trace"):
        if key in br and isinstance(br[key], list) and len(br[key]) > 0:
            return br[key]

    # Raise explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would make this function return None.
    raise AssertionError(
        f"no kernel dispatch records found in JSON buffer_records keys={list(br.keys())}"
    )
72 | 43 |
|
73 | 44 |
|
| 45 | +def _as_int(val, *, field="value"): |
| 46 | + assert val is not None, f"missing {field}" |
| 47 | + try: |
| 48 | + return int(val) |
| 49 | + except Exception as e: |
| 50 | + raise AssertionError(f"failed to parse int for {field}: {val!r} ({e})") from e |
| 51 | + |
| 52 | + |
def _extract_dispatch_id_from_json_record(r):
    """
    Return the integer identifier used to key a JSON kernel record.

    Lookup order:
      1. ``r["dispatch_info"]["dispatch_id"]`` when ``dispatch_info`` is a dict.
      2. ``r["correlation_id"]["internal"]`` when ``correlation_id`` is a dict.
      3. ``r["correlation_id"]`` itself when it is not a dict.

    Raises AssertionError (via ``_as_int``) when no identifier is found or
    the value cannot be converted to an int.
    """
    info = r.get("dispatch_info", {})
    candidate = info.get("dispatch_id", None) if isinstance(info, dict) else None

    if candidate is None:
        # No dispatch id available: fall back to the correlation id.
        correlation = r.get("correlation_id", {})
        if isinstance(correlation, dict):
            candidate = correlation.get("internal", None)
        else:
            candidate = correlation

    return _as_int(candidate, field="dispatch_id/correlation_id")
| 72 | + |
| 73 | + |
def build_json_duration_map(records):
    """
    Build a lookup of kernel timing data from JSON kernel records.

    Args:
        records: Iterable of JSON kernel dispatch records (dict-like). Each
            must provide ``start_timestamp`` and ``end_timestamp`` plus an
            identifier understood by ``_extract_dispatch_id_from_json_record``.

    Returns:
        dict mapping ``int`` dispatch id -> ``(start, end, duration)`` where
        ``duration == end - start``.

    Raises:
        AssertionError: If a record has missing or invalid timestamps, if two
            records share an id but disagree on their timing, or if no
            records were provided.
    """
    m = {}
    for r in records:
        did = _extract_dispatch_id_from_json_record(r)

        start = _as_int(r.get("start_timestamp"), field="start_timestamp")
        end = _as_int(r.get("end_timestamp"), field="end_timestamp")

        assert start > 0 and end > 0, f"invalid timestamps start={start} end={end}"
        assert end >= start, f"end before start: start={start} end={end}"

        entry = (start, end, end - start)

        # Do not let a later record silently replace an earlier one: with a
        # non-unique key the comparison against the DB would be ambiguous.
        # Exact duplicates are harmless and tolerated.
        previous = m.setdefault(did, entry)
        if previous != entry:
            raise AssertionError(
                f"conflicting JSON records for dispatch_id={did}: {previous} vs {entry}"
            )

    assert len(m) > 0, "no kernel records extracted from JSON"
    return m
105 | 93 |
|
106 | 94 |
|
107 | | -def test_rocpd_kernel_trace_duration(json_data, db_path, tmp_path): |
def load_kernel_rows_via_rocpd(db_path):
    """
    Read kernel rows from a rocpd database through the rocpd Python API.

    The query selects from the "kernels" relation and exposes the columns
    Dispatch_Id, Correlation_Id, Start_Timestamp, End_Timestamp and Duration
    (computed as end - start).

    Returns a list of dicts, one per row, keyed by those column names.
    Raises AssertionError if rocpd cannot be imported or no rows come back.
    """
    try:
        import rocpd
    except Exception as e:
        raise AssertionError(
            f"failed to import rocpd python module. Ensure PYTHONPATH is set for rocprofiler-sdk build tree. ({e})"
        ) from e

    # Only the columns needed for the strict consistency checks.
    # NOTE: per the original author, rocpd/csv.py::write_kernel_csv reads the
    # same "kernels" relation.
    query = """
        SELECT
            dispatch_id AS Dispatch_Id,
            stack_id AS Correlation_Id,
            start AS Start_Timestamp,
            end AS End_Timestamp,
            (end - start) AS Duration
        FROM "kernels"
        ORDER BY
            guid ASC, start ASC, end DESC
    """

    # The database path is passed to connect() wrapped in a list.
    connection = rocpd.connect([db_path])
    cursor = rocpd.execute(connection, query)

    column_names = [description[0] for description in cursor.description]
    rows = [dict(zip(column_names, record)) for record in cursor.fetchall()]

    assert len(rows) > 0, f"no rows returned from kernels table in db: {db_path}"
    return rows
| 130 | + |
| 131 | + |
def test_rocpd_kernel_trace_duration(json_data, db_path):
    """
    Check that kernel timing stored in the rocpd DB agrees exactly with JSON.

    ``json_data`` and ``db_path`` are pytest fixtures (presumably supplied by
    the suite's conftest from a single rocprofv3 run producing both json and
    rocpd output -- confirm against the fixture definitions).

    Checks performed:
      - every DB row is internally consistent: positive timestamps,
        end >= start, Duration >= 0 and Duration == End - Start
      - every DB row's dispatch id is present in the JSON records
      - Start, End and Duration are identical to the JSON values
        (zero tolerance)
    """
    # DB side: rows read via the rocpd Python API (no CSV round trip).
    db_rows = load_kernel_rows_via_rocpd(db_path)

    # JSON side: dispatch_id -> (start, end, duration).
    json_map = build_json_duration_map(extract_json_kernel_records(json_data))

    total_count = len(db_rows)
    matched_count = 0
    missing_in_json = []
    mismatches = []

    for row in db_rows:
        did = _as_int(row.get("Dispatch_Id"), field="Dispatch_Id")
        start = _as_int(row.get("Start_Timestamp"), field="Start_Timestamp")
        end = _as_int(row.get("End_Timestamp"), field="End_Timestamp")
        dur = _as_int(row.get("Duration"), field="Duration")

        # The row must be self-consistent before it is compared with JSON.
        assert start > 0 and end > 0, f"invalid DB timestamps: start={start} end={end} dispatch_id={did}"
        assert end >= start, f"DB end before start: start={start} end={end} dispatch_id={did}"
        assert dur >= 0, f"negative DB duration: duration={dur} dispatch_id={did}"
        assert dur == (end - start), (
            f"DB duration mismatch: duration={dur} != end-start={end - start} dispatch_id={did}"
        )

        expected = json_map.get(did)
        if expected is None:
            missing_in_json.append(did)
            continue

        matched_count += 1
        j_start, j_end, j_dur = expected
        start_diff, end_diff, dur_diff = start - j_start, end - j_end, dur - j_dur

        if any((start_diff, end_diff, dur_diff)):
            mismatches.append(
                {
                    "dispatch_id": did,
                    "db_start": start,
                    "json_start": j_start,
                    "start_diff": start_diff,
                    "db_end": end,
                    "json_end": j_end,
                    "end_diff": end_diff,
                    "db_dur": dur,
                    "json_dur": j_dur,
                    "dur_diff": dur_diff,
                }
            )

    # Unmatched rows are reported first, with a sample of the offending ids.
    if missing_in_json:
        sample = missing_in_json[:10]
        raise AssertionError(
            "Some DB kernel rows had dispatch_id not present in JSON records. "
            "Since JSON and rocpd come from the same execution, dispatch IDs should align.\n"
            f"Missing count: {len(missing_in_json)}/{total_count}\n"
            f"Sample missing dispatch_ids: {sample}"
        )

    if mismatches:
        shown = mismatches[:10]
        report = [
            "",
            "TIMESTAMP/DURATION MISMATCHES DETECTED (zero tolerance)",
            f"{'Dispatch':<12} {'StartDiff':<12} {'EndDiff':<12} {'DurDiff':<12}",
            "=" * 56,
        ]
        report.extend(
            f"{m['dispatch_id']:<12} {m['start_diff']:<12} {m['end_diff']:<12} {m['dur_diff']:<12}"
            for m in shown
        )
        if len(mismatches) > 10:
            report.append(f"... and {len(mismatches) - 10} more mismatches")

        first = mismatches[0]
        detail = (
            f"\n\nExample mismatch for dispatch {first['dispatch_id']}:\n"
            f"  DB:   start={first['db_start']}, end={first['db_end']}, duration={first['db_dur']}\n"
            f"  JSON: start={first['json_start']}, end={first['json_end']}, duration={first['json_dur']}\n"
            f"  Diff: start={first['start_diff']}, end={first['end_diff']}, duration={first['dur_diff']}\n"
            f"Total mismatches: {len(mismatches)}/{total_count}\n"
            "NOTE: Since JSON and rocpd come from the same execution, these should be identical."
        )
        raise AssertionError("\n".join(report) + detail)

    assert matched_count > 0, "No DB rows matched JSON records"
    assert matched_count == total_count, f"Only {matched_count}/{total_count} DB rows matched JSON"
212 | 235 |
|
213 | 236 |
|
214 | 237 | if __name__ == "__main__": |
|
0 commit comments