Commit 8e6956d

final
Signed-off-by: Yang Wang <[email protected]>
1 parent ab6e6cf commit 8e6956d


.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py

Lines changed: 0 additions & 182 deletions
@@ -139,188 +139,6 @@ def argparser():

     return parser.parse_args()

-
-class BenchmarkFetcher:
-    def __init__(
-        self,
-        env="prod",
-        repo="pytorch/pytorch",
-        benchmark_name="",
-        disable_logging=False,
-        group_table_fields=None,
-        group_row_fields=None,
-        processor_funcs=None,
-    ):
-        """
-        Initialize the ExecutorchBenchmarkFetcher.
-
-        Args:
-            env: Environment to use ("local" or "prod")
-            disable_logging: Whether to suppress log output
-            group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model)
-            group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket)
-        """
-        self.env = env
-        self.base_url = self._get_base_url()
-        self.query_group_table_by_fields = (
-            group_table_fields
-            if group_table_fields
-            else ["device", "backend", "arch", "model"]
-        )
-        self.query_group_row_by_fields = (
-            group_row_fields
-            if group_row_fields
-            else ["workflow_id", "job_id", "granularity_bucket"]
-        )
-        self.data = None
-        self.disable_logging = disable_logging
-        self.processor_funcs = processor_funcs if processor_funcs else []
-
-        self.repo = repo
-        self.benchmark_name = benchmark_name
-
-    def _get_base_url(self) -> str:
-        """
-        Get the base URL for API requests based on environment.
-
-        Returns:
-            Base URL string for the configured environment
-        """
-        base_urls = {
-            "local": "http://localhost:3000",
-            "prod": "https://hud.pytorch.org",
-        }
-        return base_urls[self.env]
-
-    def _fetch_data(
-        self, start_time: str, end_time: str
-    ) -> Optional[List[Dict[str, Any]]]:
-        """
-        Fetch and process benchmark data for the specified time range.
-
-        Args:
-            start_time: ISO8601 formatted start time
-            end_time: ISO8601 formatted end time
-
-        Returns:
-            Processed benchmark data or None if fetch failed
-        """
-        data = self._fetch_data(start_time, end_time)
-        if data is None:
-            return None
-        self.data = self._process(data)
-        return self.data
-
-    def to_df(self) -> Any:
-        if not self.data:
-            return
-        dfs = [
-            {"groupInfo": item["groupInfo"], "df": pd.DataFrame(item["rows"])}
-            for item in self.data
-        ]
-        return dfs
-
-    def _process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """
-        Process raw benchmark data.
-
-        This method:
-        1. Normalizes string values in groupInfo
-        2. Creates table_name from group info components
-        3. Determines aws_type (public/private) based on device name
-        4. Sorts results by table_name
-
-        Args:
-            data: Raw benchmark data from API
-
-        Returns:
-            Processed benchmark data
-        """
-        for item in data:
-            # normalized string values in groupInfo
-            item["groupInfo"] = {
-                k: self.normalize_string(v)
-                for k, v in item.get("groupInfo", {}).items()
-                if v is not None and isinstance(v, str)
-            }
-            group = item.get("groupInfo", {})
-            name = self._generate_table_name(group, self.query_group_table_by_fields)
-
-            # Add full name joined by the group key fields
-            item["table_name"] = name
-
-            # Mark aws_type: private or public
-            if group.get("device", "").find("private") != -1:
-                item["groupInfo"]["aws_type"] = "private"
-            else:
-                item["groupInfo"]["aws_type"] = "public"
-
-        data.sort(key=lambda x: x["table_name"])
-        logging.info(f"fetched {len(data)} table views")
-        return data
-
-    def normalize_string(self, s, replace="_"):
-        return s.lower().replace(" ", replace)
-
-    def _generate_table_name(self, group_info: dict, fields: list[str]) -> str:
-        name = "|".join(
-            group_info[k] for k in fields if k in group_info and group_info[k]
-        )
-        return self.normalize_string(name)
-
-    def _call_api(self, start_time, end_time):
-        url = f"{self.base_url}/api/benchmark/group_data"
-        params_object = BenchmarkQueryGroupDataParams(
-            repo="pytorch/executorch",
-            benchmark_name="ExecuTorch",
-            start_time=start_time,
-            end_time=end_time,
-            group_table_by_fields=self.query_group_table_by_fields,
-            group_row_by_fields=self.query_group_row_by_fields,
-        )
-        params = {k: v for k, v in params_object.__dict__.items() if v is not None}
-        response = requests.get(url, params=params)
-        if response.status_code == 200:
-            return response.json()
-        else:
-            logging.info(f"Failed to fetch benchmark data ({response.status_code})")
-            logging.info(response.text)
-            return None
-
-    def _write_multiple_csv_files(
-        self, data_list: List[Dict[str, Any]], output_dir: str, file_prefix=""
-    ) -> None:
-        """
-        Write multiple benchmark results to separate CSV files.
-
-        Each entry in `data_list` becomes its own CSV file.
-
-        Args:
-            data_list: List of benchmark result dictionaries
-            output_dir: Directory to save the CSV files
-        """
-        os.makedirs(output_dir, exist_ok=True)
-        logging.info(
-            f"\n ========= Generating multiple CSV files in {output_dir} ========= \n"
-        )
-        for idx, entry in enumerate(data_list):
-            file_name = entry.get("short_name", f"file{idx+1}")
-
-            if file_prefix:
-                file_name = file_prefix + file_name
-            if len(file_name) > 100:
-                logging.warning(
-                    f"File name '{file_name}' is too long, truncating to 100 characters"
-                )
-                file_name = file_name[:100]
-            file_path = os.path.join(output_dir, f"{file_name}.csv")
-
-            rows = entry.get("rows", [])
-            logging.info(f"Writing CSV: {file_path} with {len(rows)} rows")
-            df = pd.DataFrame(rows)
-            df.to_csv(file_path, index=False)
-
-
 class ExecutorchBenchmarkFetcher:
     """
     Fetch and process benchmark data from HUD API for ExecutorchBenchmark.
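
For reference, below is a minimal standalone sketch, not part of this commit, of how the HUD group_data endpoint targeted by the removed _call_api method could be queried directly with requests. The endpoint path, repo, benchmark name, and grouping fields are taken from the deleted code; how the API expects list-valued query parameters to be serialized is an assumption here (requests repeats the key for each list element).

import logging

import requests


def fetch_benchmark_groups(
    start_time: str,
    end_time: str,
    base_url: str = "https://hud.pytorch.org",
):
    # Query parameters mirror the removed BenchmarkQueryGroupDataParams usage.
    params = {
        "repo": "pytorch/executorch",
        "benchmark_name": "ExecuTorch",
        "start_time": start_time,  # ISO8601, e.g. "2025-01-01T00:00:00"
        "end_time": end_time,      # ISO8601
        "group_table_by_fields": ["device", "backend", "arch", "model"],
        "group_row_by_fields": ["workflow_id", "job_id", "granularity_bucket"],
    }
    response = requests.get(f"{base_url}/api/benchmark/group_data", params=params)
    if response.status_code != 200:
        logging.info(f"Failed to fetch benchmark data ({response.status_code})")
        logging.info(response.text)
        return None
    return response.json()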

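And a second sketch, also not part of this commit, showing how each grouped result could be turned into a DataFrame and written to its own CSV file, in the spirit of the removed to_df and _write_multiple_csv_files methods. The "groupInfo"/"rows" layout is taken from the deleted code; joining the table name with "_" instead of "|" is a choice made here to keep file names portable.

import os

import pandas as pd


def write_group_csvs(data, output_dir="./benchmark_csv"):
    os.makedirs(output_dir, exist_ok=True)
    for idx, item in enumerate(data):
        group = item.get("groupInfo", {})
        # Build a normalized table name from the grouping fields, as the removed
        # _generate_table_name/normalize_string helpers did (lowercase, spaces -> "_").
        name = "_".join(
            str(group[k]) for k in ("device", "backend", "arch", "model") if group.get(k)
        ).lower().replace(" ", "_") or f"table{idx + 1}"
        # One CSV per table view, truncated to keep file names a manageable length.
        df = pd.DataFrame(item.get("rows", []))
        df.to_csv(os.path.join(output_dir, f"{name[:100]}.csv"), index=False)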