Skip to content

Commit 39026fb

Browse files
committed
Fix flatten function
1 parent 66cf70f commit 39026fb

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

sharktuner/sharktuner/candidate_ordering.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -151,34 +151,43 @@ def flatten_records(
151151
tuning_records: list[TuningRecord],
152152
) -> tuple[list[str], list[dict[str, Any]]]:
153153
"""
154-
Flatten a list of `TuningRecord` objects to CSV headers and rows
154+
Flatten a list of `TuningRecord` objects into CSV headers and rows.
155155
156156
- Each record becomes one CSV row.
157-
- Top-level attributes (e.g., `gen_id`, `benchmark_time_us`) are written as individual columns.
158-
- Nested object (i.e., `knob`) is flattened using dot notation: knob.tile_m, knob.intrinsic_mn
157+
- Top-level attributes (e.g., `gen_id`, `benchmark_time_us`) appear as individual columns.
158+
- Nested objects (e.g., `knob`) are flattened into columns like `knob.M`, `knob.tile_m`.
159+
160+
The original top-level attribute (e.g., 'knob') is removed once nesting is flattened.
159161
"""
160162
rows = []
161163
headers = []
164+
unneeded_headers = []
162165

163166
for tuning_record in tuning_records:
164167
row = {}
165-
for k, v in vars(tuning_record).items():
166-
if hasattr(v, "__dict__"):
167-
nested = vars(v)
168-
if nested:
169-
for nk, nv in nested.items():
170-
key = f"{k}.{nk}"
171-
row[key] = nv
172-
if key not in headers:
173-
headers.append(key)
174-
else:
168+
for attr, val in vars(tuning_record).items():
169+
if hasattr(val, "__dict__"):
170+
nested = vars(val)
171+
if not nested:
175172
continue
173+
unneeded_headers.append(attr)
174+
for sub_attr, sub_val in nested.items():
175+
key = f"{attr}.{sub_attr}"
176+
row[key] = sub_val
177+
if key not in headers:
178+
headers.append(key)
176179
else:
177-
row[k] = v
178-
if k not in headers and k != "knob":
179-
headers.append(k)
180+
row[attr] = val
181+
if attr not in headers:
182+
headers.append(attr)
180183
rows.append(row)
181184

185+
# Remove top-level attributes (e.g., 'knob') that were replaced by flattened nested fields.
186+
headers = [h for h in headers if h not in unneeded_headers]
187+
for row in rows:
188+
for unneeded in unneeded_headers:
189+
row.pop(unneeded, None)
190+
182191
return headers, rows
183192

184193

sharktuner/tests/candidate_ordering_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,6 @@ def test_flatten_records(
265265
"candidate_id": 0,
266266
"compile_status": False,
267267
"gen_id": 0,
268-
"knob": None,
269268
"to_benchmark": True,
270269
"to_compile": True,
271270
},

0 commit comments

Comments
 (0)