|
13 | 13 |
|
14 | 14 | _SUMMARY_STATS = ['min', 'max', 'mean', 'std'] |
15 | 15 |
|
| 16 | +def _summary_stats(data: pd.Series | pd.DataFrame) -> pd.Series | pd.DataFrame: |
| 17 | + stats = data.describe().loc[_SUMMARY_STATS] |
| 18 | + stats.loc['std'] = data.std(ddof=0) |
| 19 | + return stats |
16 | 20 |
|
17 | 21 | class _GPUQuerier(abc.ABC): |
18 | 22 | command = None |
@@ -348,7 +352,8 @@ def load_code_block_names(self) -> list[str]: |
348 | 352 | return sorted(self.timestamps.code_block_name.unique()) |
349 | 353 |
|
350 | 354 | def _overall_timepoint_results(self, fields: list[str]) -> pd.DataFrame: |
351 | | - return self.timepoints[fields].describe().loc[_SUMMARY_STATS].T |
| 355 | + stats = _summary_stats(self.timepoints[fields]) |
| 356 | + return stats.T |
352 | 357 |
|
353 | 358 |
|
354 | 359 | class _SQLiteDataProxy(_DataProxy): |
@@ -441,21 +446,27 @@ def load_timepoints(self, timestamp_pairs: list[tuple[float, float]]) -> pd.Data |
441 | 446 | return self._read_sql(sql) |
442 | 447 |
|
443 | 448 | def _overall_timepoint_results(self, fields: list[str]) -> pd.DataFrame: |
444 | | - sql = 'SELECT\n' |
445 | | - std_func = 'sqrt((sum({0} * {0}) - (sum({0}) * sum({0})) / count({0})) / count({0})) AS "STDDEV({0})"' |
446 | | - sql_funcs = 'MIN', 'MAX', 'AVG', 'STDDEV' |
447 | | - field_aggregates = list[str]() |
448 | | - for func in sql_funcs: |
449 | | - for field in fields: |
450 | | - aggregate = f'{func}({field})' if func != 'STDDEV' else std_func.format(field) |
451 | | - field_aggregates.append(aggregate) |
452 | | - sql += ',\n'.join(field_aggregates) |
453 | | - sql += f'\nFROM {_SQLiteDataProxy._DATA_TABLE}' |
| 449 | + cte_blocks = list[str]() |
| 450 | + selects = list[str]() |
| 451 | + for col in fields: |
| 452 | + mean_cte = f'mean_{col}' |
| 453 | + diff_cte = f'diff_{col}' |
| 454 | + cte_blocks.append(f'{mean_cte} AS (SELECT AVG({col}) AS mean FROM {_SQLiteDataProxy._DATA_TABLE})') |
| 455 | + cte_blocks.append( |
| 456 | + f'{diff_cte} AS (SELECT {col} - (SELECT mean FROM {mean_cte}) AS diff FROM {_SQLiteDataProxy._DATA_TABLE})' |
| 457 | + ) |
| 458 | + selects.append(f'MIN({col})') |
| 459 | + selects.append(f'MAX({col})') |
| 460 | + selects.append(f'(SELECT mean FROM {mean_cte}) AS "AVG({col})"') |
| 461 | + selects.append(f'(SELECT SQRT(AVG(diff * diff)) FROM {diff_cte}) AS "STDDEV({col})"') |
| 462 | + with_clause = "WITH " + ",\n ".join(cte_blocks) |
| 463 | + select_clause = "SELECT " + ",\n ".join(selects) |
| 464 | + sql = f"{with_clause}\n{select_clause} FROM {_SQLiteDataProxy._DATA_TABLE};" |
454 | 465 | results = self._read_sql(sql).squeeze() |
455 | 466 | reshaped_results = pd.DataFrame() |
456 | | - n_fields = len(fields) |
457 | | - for i, sql_func, index in zip(range(0, len(results), n_fields), sql_funcs, _SUMMARY_STATS): |
458 | | - next_row = results.iloc[i: i + n_fields] |
| 467 | + sql_funcs = 'MIN', 'MAX', 'AVG', 'STDDEV' |
| 468 | + for sql_func, index in zip(sql_funcs, _SUMMARY_STATS): |
| 469 | + next_row = results.loc[[idx.startswith(sql_func) for idx in results.index]] |
459 | 470 | next_row.index = [col.replace(sql_func, '').replace('(', '').replace(')', '') for col in next_row.index] |
460 | 471 | reshaped_results.loc[:, index] = next_row |
461 | 472 | return reshaped_results |
|
0 commit comments