Skip to content

Commit 83a5125

Browse files
committed
Completes the analysis test
1 parent 3445bf5 commit 83a5125

22 files changed

+338
-70
lines changed

src/gpu_tracker/_helper_classes.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
_SUMMARY_STATS = ['min', 'max', 'mean', 'std']
1515

16+
def _summary_stats(data: pd.Series | pd.DataFrame) -> pd.Series | pd.DataFrame:
17+
stats = data.describe().loc[_SUMMARY_STATS]
18+
stats.loc['std'] = data.std(ddof=0)
19+
return stats
1620

1721
class _GPUQuerier(abc.ABC):
1822
command = None
@@ -348,7 +352,8 @@ def load_code_block_names(self) -> list[str]:
348352
return sorted(self.timestamps.code_block_name.unique())
349353

350354
def _overall_timepoint_results(self, fields: list[str]) -> pd.DataFrame:
351-
return self.timepoints[fields].describe().loc[_SUMMARY_STATS].T
355+
stats = _summary_stats(self.timepoints[fields])
356+
return stats.T
352357

353358

354359
class _SQLiteDataProxy(_DataProxy):
@@ -441,21 +446,27 @@ def load_timepoints(self, timestamp_pairs: list[tuple[float, float]]) -> pd.Data
441446
return self._read_sql(sql)
442447

443448
def _overall_timepoint_results(self, fields: list[str]) -> pd.DataFrame:
444-
sql = 'SELECT\n'
445-
std_func = 'sqrt((sum({0} * {0}) - (sum({0}) * sum({0})) / count({0})) / count({0})) AS "STDDEV({0})"'
446-
sql_funcs = 'MIN', 'MAX', 'AVG', 'STDDEV'
447-
field_aggregates = list[str]()
448-
for func in sql_funcs:
449-
for field in fields:
450-
aggregate = f'{func}({field})' if func != 'STDDEV' else std_func.format(field)
451-
field_aggregates.append(aggregate)
452-
sql += ',\n'.join(field_aggregates)
453-
sql += f'\nFROM {_SQLiteDataProxy._DATA_TABLE}'
449+
cte_blocks = list[str]()
450+
selects = list[str]()
451+
for col in fields:
452+
mean_cte = f'mean_{col}'
453+
diff_cte = f'diff_{col}'
454+
cte_blocks.append(f'{mean_cte} AS (SELECT AVG({col}) AS mean FROM {_SQLiteDataProxy._DATA_TABLE})')
455+
cte_blocks.append(
456+
f'{diff_cte} AS (SELECT {col} - (SELECT mean FROM {mean_cte}) AS diff FROM {_SQLiteDataProxy._DATA_TABLE})'
457+
)
458+
selects.append(f'MIN({col})')
459+
selects.append(f'MAX({col})')
460+
selects.append(f'(SELECT mean FROM {mean_cte}) AS "AVG({col})"')
461+
selects.append(f'(SELECT SQRT(AVG(diff * diff)) FROM {diff_cte}) AS "STDDEV({col})"')
462+
with_clause = "WITH " + ",\n ".join(cte_blocks)
463+
select_clause = "SELECT " + ",\n ".join(selects)
464+
sql = f"{with_clause}\n{select_clause} FROM {_SQLiteDataProxy._DATA_TABLE};"
454465
results = self._read_sql(sql).squeeze()
455466
reshaped_results = pd.DataFrame()
456-
n_fields = len(fields)
457-
for i, sql_func, index in zip(range(0, len(results), n_fields), sql_funcs, _SUMMARY_STATS):
458-
next_row = results.iloc[i: i + n_fields]
467+
sql_funcs = 'MIN', 'MAX', 'AVG', 'STDDEV'
468+
for sql_func, index in zip(sql_funcs, _SUMMARY_STATS):
469+
next_row = results.loc[[idx.startswith(sql_func) for idx in results.index]]
459470
next_row.index = [col.replace(sql_func, '').replace('(', '').replace(')', '') for col in next_row.index]
460471
reshaped_results.loc[:, index] = next_row
461472
return reshaped_results

src/gpu_tracker/sub_tracker.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pickle as pkl
1010
import logging as log
1111
import typing as typ
12-
from ._helper_classes import _DataProxy, _SubTrackerLog, _SUMMARY_STATS
12+
from ._helper_classes import _DataProxy, _SubTrackerLog, _SUMMARY_STATS, _summary_stats
1313

1414

1515
class SubTracker:
@@ -168,7 +168,7 @@ def sub_tracking_results(self) -> SubTrackingResults:
168168
for code_block_name in code_block_names:
169169
time_stamp_pairs = self.load_timestamp_pairs(code_block_name)
170170
time_stamp_diffs = pd.Series([stop_time - start_time for (start_time, stop_time) in time_stamp_pairs])
171-
compute_time_results = time_stamp_diffs.describe()[_SUMMARY_STATS]
171+
compute_time_results = _summary_stats(time_stamp_diffs)
172172
compute_time_results['total'] = time_stamp_diffs.sum().item()
173173
timepoints = self.load_timepoints(time_stamp_pairs)
174174
num_non_empty_calls = sum(
@@ -179,11 +179,11 @@ def sub_tracking_results(self) -> SubTrackingResults:
179179
]
180180
)
181181
timepoints = timepoints.drop(columns='timestamp')
182+
resource_usage = _summary_stats(timepoints).T
182183
code_block_results.append(
183184
CodeBlockResults(
184185
name=code_block_name, num_timepoints=len(timepoints), num_calls=len(time_stamp_pairs),
185-
num_non_empty_calls=num_non_empty_calls, compute_time=compute_time_results,
186-
resource_usage=timepoints.describe().loc[_SUMMARY_STATS].T
186+
num_non_empty_calls=num_non_empty_calls, compute_time=compute_time_results, resource_usage=resource_usage
187187
)
188188
)
189189
return SubTrackingResults(overall_results, static_data, code_block_results)
@@ -409,7 +409,10 @@ def _dict_to_str(string: str, results: dict, indent: int, no_title_keys: set[str
409409
string = _dict_to_str(string, value, indent + 1, no_title_keys)
410410
elif type(value) is pd.DataFrame:
411411
string += f'{key}:\n'
412-
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 5000):
412+
with pd.option_context(
413+
'display.max_rows', None, 'display.max_columns', None, 'display.width', 5000, 'display.float_format',
414+
lambda x: f'{float(f"{x:.8f}")}'
415+
):
413416
df_str = str(value)
414417
df_str = '\n'.join(indent_str + '\t' + line for line in df_str.splitlines())
415418
string += df_str + '\n'
@@ -418,7 +421,7 @@ def _dict_to_str(string: str, results: dict, indent: int, no_title_keys: set[str
418421
for value in value:
419422
string = _dict_to_str(string, value, indent + 1, no_title_keys) + '\n'
420423
else:
421-
value = f'{value:.4f}' if type(value) is float else value
424+
value = f'{value:.8f}' if type(value) is float else value
422425
string += f'{key}:{" " * (max_key_len - len(key))} {value}\n'
423426
return string
424427

src/gpu_tracker/tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
self._is_linux = platform.system().lower() == 'linux'
6666
cannot_connect_warning = ('The {} command is installed but cannot connect to a GPU. '
6767
'The GPU RAM and GPU utilization values will remain 0.0.')
68+
self.data_proxy = _DataProxy.create(tracking_file, overwrite)
6869
if gpu_brand is None:
6970
nvidia_available = _NvidiaQuerier.is_available()
7071
nvidia_installed = nvidia_available is not None
@@ -119,7 +120,6 @@ def __init__(
119120
self._resource_usage = ResourceUsage(
120121
max_ram=max_ram, max_gpu_ram=max_gpu_ram, cpu_utilization=cpu_utilization, gpu_utilization=gpu_utilization,
121122
compute_time=compute_time)
122-
self.data_proxy = _DataProxy.create(tracking_file, overwrite)
123123
if self.data_proxy is not None:
124124
static_data = _StaticData(
125125
ram_unit, gpu_ram_unit, time_unit, max_ram.system_capacity, max_gpu_ram.system_capacity, system_core_count,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
position,process_id,timestamp
2-
START,1234,12
3-
STOP,1234,13
2+
0,1234,12
3+
1,1234,13

tests/data/decorated-function.csv

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
position,process_id,timestamp
2-
START,1234,0
3-
STOP,1234,1
4-
START,1234,2
5-
STOP,1234,3
6-
START,1234,4
7-
STOP,1234,5
8-
START,1234,6
9-
STOP,1234,7
10-
START,1234,8
11-
STOP,1234,9
12-
START,1234,10
13-
STOP,1234,11
2+
0,1234,0
3+
1,1234,1
4+
0,1234,2
5+
1,1234,3
6+
0,1234,4
7+
1,1234,5
8+
0,1234,6
9+
1,1234,7
10+
0,1234,8
11+
1,1234,9
12+
0,1234,10
13+
1,1234,11

tests/data/sub-tracker.csv

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
position,process_id,timestamp
2-
START,1234,0
3-
STOP,1234,1
4-
START,1234,2
5-
STOP,1234,3
6-
START,1234,4
7-
STOP,1234,5
8-
START,1234,6
9-
STOP,1234,7
10-
START,1234,8
11-
STOP,1234,9
2+
0,1234,0
3+
1,1234,1
4+
0,1234,2
5+
1,1234,3
6+
0,1234,4
7+
1,1234,5
8+
0,1234,6
9+
1,1234,7
10+
0,1234,8
11+
1,1234,9

tests/data/sub-tracking-results/files-to-combine/1723811.sub.tmp.sqlite renamed to tests/data/sub-tracking-results/files-to-combine/1723811.sub-tracking.sqlite

File renamed without changes.

tests/data/sub-tracking-results/files-to-combine/1723814.sub.tmp.sqlite renamed to tests/data/sub-tracking-results/files-to-combine/1723814.sub-tracking.sqlite

File renamed without changes.

tests/data/sub-tracking-results/files-to-combine/1723815.sub.tmp.sqlite renamed to tests/data/sub-tracking-results/files-to-combine/1723815.sub-tracking.sqlite

File renamed without changes.

tests/data/sub-tracking-results/files-to-combine/main.sub.tmp.sqlite renamed to tests/data/sub-tracking-results/files-to-combine/main.sub-tracking.sqlite

File renamed without changes.

0 commit comments

Comments
 (0)