Skip to content

Commit 7859d02

Browse files
Merge pull request #131 from NessieCanCode/refactor-export_summary-for-robustness-and-readability
Refactor export_summary for configurability and validation
2 parents 9ee2e68 + f7a3d96 commit 7859d02

File tree

2 files changed

+166
-107
lines changed

2 files changed

+166
-107
lines changed

src/slurmdb.py

Lines changed: 121 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import date, datetime, timedelta
88
from calendar import monthrange
99
from itertools import product
10+
from typing import Any, Dict, List, Optional, Tuple, Union
1011

1112

1213
try:
@@ -590,46 +591,37 @@ def fetch_invoices(self, start_date=None, end_date=None):
590591
for r in rows
591592
]
592593

593-
def export_summary(self, start_time, end_time):
594-
"""Export a summary of usage and costs.
595-
596-
Rates represent a fixed cost per core-hour (for example, dollars
597-
per core-hour) and must be non-negative. ``discount`` values are
598-
fractional percentages, where ``0.2`` means a 20% discount, and
599-
they must fall between 0 and 1, inclusive. A :class:`ValueError`
600-
is raised if these constraints are violated.
601-
"""
602-
603-
usage, totals = self.aggregate_usage(start_time, end_time)
604-
summary = {
605-
'summary': {},
606-
'details': [],
607-
'daily': [],
608-
'monthly': [],
609-
'yearly': [],
610-
'invoices': [],
611-
}
612-
total_ch = 0.0
613-
total_gpu = 0.0
614-
total_cost = 0.0
615-
616-
rates_path = os.path.join(os.path.dirname(__file__), 'rates.json')
594+
def _load_rates(self, rates_file: Optional[str]) -> Dict[str, Any]:
595+
path = rates_file or os.path.join(os.path.dirname(__file__), 'rates.json')
617596
try:
618-
with open(rates_path) as fh:
619-
rates_cfg = json.load(fh)
597+
with open(path) as fh:
598+
return json.load(fh)
620599
except OSError as e:
621-
logging.warning("Unable to read rates file %s: %s", rates_path, e)
622-
rates_cfg = {}
600+
logging.warning("Unable to read rates file %s: %s", path, e)
601+
return {}
623602
except json.JSONDecodeError as e:
624-
logging.error("Failed to parse rates file %s: %s", rates_path, e)
603+
logging.error("Failed to parse rates file %s: %s", path, e)
625604
raise
605+
606+
def _validate_cluster_cores(self, resources: Dict[str, Any]) -> int:
607+
cores = resources.get('cores')
608+
if not isinstance(cores, (int, float)) or cores <= 0:
609+
raise ValueError(f"Invalid cluster core count {cores}")
610+
return int(cores)
611+
612+
def _build_account_details(
613+
self,
614+
usage: Dict[str, Dict[str, Any]],
615+
rates_cfg: Dict[str, Any],
616+
) -> Tuple[List[Dict[str, Any]], float, float, float]:
626617
default_rate = rates_cfg.get('defaultRate', 0.01)
627618
default_gpu_rate = rates_cfg.get('defaultGpuRate', 0.0)
628619
overrides = rates_cfg.get('overrides', {})
629620
historical = rates_cfg.get('historicalRates', {})
630621
gpu_historical = rates_cfg.get('historicalGpuRates', {})
631-
resources = self.cluster_resources()
632-
cluster_cores = resources.get('cores')
622+
623+
details: List[Dict[str, Any]] = []
624+
total_ch = total_gpu = total_cost = 0.0
633625

634626
for month, accounts in usage.items():
635627
base_rate = historical.get(month, default_rate)
@@ -643,23 +635,19 @@ def export_summary(self, start_time, end_time):
643635
if rate < 0:
644636
raise ValueError(f"Invalid rate {rate} for account {account}")
645637
if gpu_rate < 0:
646-
raise ValueError(
647-
f"Invalid GPU rate {gpu_rate} for account {account}"
648-
)
638+
raise ValueError(f"Invalid GPU rate {gpu_rate} for account {account}")
649639
if not 0 <= discount <= 1:
650-
raise ValueError(
651-
f"Invalid discount {discount} for account {account}"
652-
)
640+
raise ValueError(f"Invalid discount {discount} for account {account}")
653641

654642
acct_cost = vals['core_hours'] * rate + vals.get('gpu_hours', 0.0) * gpu_rate
655643
if 0 < discount < 1:
656644
acct_cost *= 1 - discount
657-
users = []
645+
users: List[Dict[str, Any]] = []
658646
for user, uvals in vals.get('users', {}).items():
659647
u_cost = uvals['core_hours'] * rate
660648
if 0 < discount < 1:
661649
u_cost *= 1 - discount
662-
jobs = []
650+
jobs: List[Dict[str, Any]] = []
663651
for job, jvals in uvals.get('jobs', {}).items():
664652
j_cost = jvals['core_hours'] * rate
665653
if 0 < discount < 1:
@@ -687,7 +675,7 @@ def export_summary(self, start_time, end_time):
687675
'jobs': jobs,
688676
}
689677
)
690-
summary['details'].append(
678+
details.append(
691679
{
692680
'account': account,
693681
'core_hours': round(vals['core_hours'], 2),
@@ -699,72 +687,120 @@ def export_summary(self, start_time, end_time):
699687
total_ch += vals['core_hours']
700688
total_gpu += vals.get('gpu_hours', 0.0)
701689
total_cost += acct_cost
702-
start_dt = (
703-
_fromisoformat(start_time)
704-
if isinstance(start_time, str)
705-
else datetime.fromtimestamp(start_time)
706-
)
707-
end_dt = (
708-
_fromisoformat(end_time)
709-
if isinstance(end_time, str)
710-
else datetime.fromtimestamp(end_time)
711-
)
712-
summary['summary'] = {
713-
'period': f"{start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}",
714-
'total': round(total_cost, 2),
715-
'core_hours': round(total_ch, 2),
716-
'gpu_hours': round(total_gpu, 2),
717-
'cluster': resources,
718-
}
719-
if cluster_cores:
720-
start_date = start_dt.date()
721-
end_date = end_dt.date()
722-
current = date(start_date.year, start_date.month, 1)
723-
end_marker = date(end_date.year, end_date.month, 1)
724-
projected_revenue = 0.0
725-
while current <= end_marker:
726-
days_in_month = monthrange(current.year, current.month)[1]
727-
month_start = date(current.year, current.month, 1)
728-
month_end = date(current.year, current.month, days_in_month)
729-
overlap_start = max(month_start, start_date)
730-
overlap_end = min(month_end, end_date)
731-
if overlap_start <= overlap_end:
732-
days = (overlap_end - overlap_start).days + 1
733-
rate = historical.get(current.strftime('%Y-%m'), default_rate)
734-
projected_revenue += cluster_cores * 24 * days * rate
735-
if current.month == 12:
736-
current = date(current.year + 1, 1, 1)
737-
else:
738-
current = date(current.year, current.month + 1, 1)
739-
summary['summary']['projected_revenue'] = round(projected_revenue, 2)
740-
summary['daily'] = [
690+
return details, total_ch, total_gpu, total_cost
691+
692+
def _build_time_series(
693+
self, totals: Dict[str, Any]
694+
) -> Tuple[List[Dict[str, float]], List[Dict[str, float]], List[Dict[str, float]]]:
695+
daily = [
741696
{
742697
'date': d,
743698
'core_hours': round(totals['daily'].get(d, 0.0), 2),
744699
'gpu_hours': round(totals.get('daily_gpu', {}).get(d, 0.0), 2),
745700
}
746701
for d in sorted(set(totals['daily']) | set(totals.get('daily_gpu', {})))
747702
]
748-
summary['monthly'] = [
703+
monthly = [
749704
{
750705
'month': m,
751706
'core_hours': round(totals['monthly'].get(m, 0.0), 2),
752707
'gpu_hours': round(totals.get('monthly_gpu', {}).get(m, 0.0), 2),
753708
}
754709
for m in sorted(set(totals['monthly']) | set(totals.get('monthly_gpu', {})))
755710
]
756-
summary['yearly'] = [
711+
yearly = [
757712
{
758713
'year': y,
759714
'core_hours': round(totals['yearly'].get(y, 0.0), 2),
760715
'gpu_hours': round(totals.get('yearly_gpu', {}).get(y, 0.0), 2),
761716
}
762717
for y in sorted(set(totals['yearly']) | set(totals.get('yearly_gpu', {})))
763718
]
764-
summary['invoices'] = self.fetch_invoices(start_time, end_time)
765-
summary['partitions'] = sorted(totals.get('partitions', []))
766-
summary['accounts'] = sorted(totals.get('accounts', []))
767-
summary['users'] = sorted(totals.get('users', []))
719+
return daily, monthly, yearly
720+
721+
def _calculate_projected_revenue(
722+
self,
723+
start_dt: datetime,
724+
end_dt: datetime,
725+
cluster_cores: int,
726+
rates_cfg: Dict[str, Any],
727+
) -> float:
728+
default_rate = rates_cfg.get('defaultRate', 0.01)
729+
historical = rates_cfg.get('historicalRates', {})
730+
start_date = start_dt.date()
731+
end_date = end_dt.date()
732+
current = date(start_date.year, start_date.month, 1)
733+
end_marker = date(end_date.year, end_date.month, 1)
734+
projected_revenue = 0.0
735+
while current <= end_marker:
736+
days_in_month = monthrange(current.year, current.month)[1]
737+
month_start = date(current.year, current.month, 1)
738+
month_end = date(current.year, current.month, days_in_month)
739+
overlap_start = max(month_start, start_date)
740+
overlap_end = min(month_end, end_date)
741+
if overlap_start <= overlap_end:
742+
days = (overlap_end - overlap_start).days + 1
743+
rate = historical.get(current.strftime('%Y-%m'), default_rate)
744+
projected_revenue += cluster_cores * 24 * days * rate
745+
if current.month == 12:
746+
current = date(current.year + 1, 1, 1)
747+
else:
748+
current = date(current.year, current.month + 1, 1)
749+
return round(projected_revenue, 2)
750+
751+
def export_summary(
752+
self,
753+
start_time: Union[str, float],
754+
end_time: Union[str, float],
755+
rates_file: Optional[str] = None,
756+
) -> Dict[str, Any]:
757+
"""Export a summary of usage and costs.
758+
759+
Rates represent a fixed cost per core-hour (for example, dollars
760+
per core-hour) and must be non-negative. ``discount`` values are
761+
fractional percentages, where ``0.2`` means a 20% discount, and
762+
they must fall between 0 and 1, inclusive. A :class:`ValueError`
763+
is raised if these constraints are violated.
764+
"""
765+
766+
usage, totals = self.aggregate_usage(start_time, end_time)
767+
rates_cfg = self._load_rates(rates_file)
768+
resources = self.cluster_resources()
769+
cluster_cores = self._validate_cluster_cores(resources)
770+
details, total_ch, total_gpu, total_cost = self._build_account_details(usage, rates_cfg)
771+
772+
start_dt = (
773+
_fromisoformat(start_time)
774+
if isinstance(start_time, str)
775+
else datetime.fromtimestamp(start_time)
776+
)
777+
end_dt = (
778+
_fromisoformat(end_time)
779+
if isinstance(end_time, str)
780+
else datetime.fromtimestamp(end_time)
781+
)
782+
783+
daily, monthly, yearly = self._build_time_series(totals)
784+
summary = {
785+
'summary': {
786+
'period': f"{start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}",
787+
'total': round(total_cost, 2),
788+
'core_hours': round(total_ch, 2),
789+
'gpu_hours': round(total_gpu, 2),
790+
'cluster': resources,
791+
},
792+
'details': details,
793+
'daily': daily,
794+
'monthly': monthly,
795+
'yearly': yearly,
796+
'invoices': self.fetch_invoices(start_time, end_time),
797+
'partitions': sorted(totals.get('partitions', [])),
798+
'accounts': sorted(totals.get('accounts', [])),
799+
'users': sorted(totals.get('users', [])),
800+
}
801+
summary['summary']['projected_revenue'] = self._calculate_projected_revenue(
802+
start_dt, end_dt, cluster_cores, rates_cfg
803+
)
768804
return summary
769805

770806

test/unit/billing_summary.test.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,22 @@ def test_export_summary_aggregates_costs(self):
2121
with mock.patch.object(
2222
SlurmDB,
2323
'aggregate_usage',
24-
return_value=(usage, {
25-
'daily': {},
26-
'monthly': {},
27-
'yearly': {},
28-
'daily_gpu': {},
29-
'monthly_gpu': {},
30-
'yearly_gpu': {},
31-
}),
24+
return_value=(
25+
usage,
26+
{
27+
'daily': {},
28+
'monthly': {},
29+
'yearly': {},
30+
'daily_gpu': {},
31+
'monthly_gpu': {},
32+
'yearly_gpu': {},
33+
},
34+
),
35+
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=invoices), mock.patch.object(
36+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
3237
):
33-
with mock.patch.object(SlurmDB, 'fetch_invoices', return_value=invoices):
34-
db = SlurmDB()
35-
summary = db.export_summary('2023-10-01', '2023-10-31')
38+
db = SlurmDB()
39+
summary = db.export_summary('2023-10-01', '2023-10-31')
3640
self.assertEqual(summary['summary']['total'], 1.2)
3741
self.assertEqual(summary['details'][0]['account'], 'acct')
3842
self.assertEqual(summary['details'][0]['core_hours'], 10.0)
@@ -62,7 +66,9 @@ def test_export_summary_applies_overrides_and_discounts(self):
6266
SlurmDB,
6367
'aggregate_usage',
6468
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
65-
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]):
69+
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch.object(
70+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
71+
):
6672
db = SlurmDB()
6773
summary = db.export_summary('2024-02-01', '2024-02-29')
6874
costs = {d['account']: d['cost'] for d in summary['details']}
@@ -101,7 +107,9 @@ def test_export_summary_preserves_job_details(self):
101107
SlurmDB,
102108
'aggregate_usage',
103109
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
104-
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]):
110+
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch.object(
111+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
112+
):
105113
db = SlurmDB()
106114
summary = db.export_summary('2024-03-01', '2024-03-31')
107115
job = summary['details'][0]['users'][0]['jobs'][0]
@@ -133,6 +141,8 @@ def fake_open(path, *args, **kwargs):
133141
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
134142
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch(
135143
'builtins.open', side_effect=fake_open
144+
), mock.patch.object(
145+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
136146
):
137147
db = SlurmDB()
138148
with self.assertRaises(ValueError):
@@ -157,6 +167,8 @@ def fake_open(path, *args, **kwargs):
157167
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
158168
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch(
159169
'builtins.open', side_effect=fake_open
170+
), mock.patch.object(
171+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
160172
):
161173
db = SlurmDB()
162174
with self.assertRaises(ValueError):
@@ -181,6 +193,25 @@ def fake_open(path, *args, **kwargs):
181193
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
182194
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch(
183195
'builtins.open', side_effect=fake_open
196+
), mock.patch.object(
197+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
198+
):
199+
db = SlurmDB()
200+
with self.assertRaises(ValueError):
201+
db.export_summary('2023-10-01', '2023-10-31')
202+
203+
def test_export_summary_invalid_cluster_cores(self):
204+
usage = {
205+
'2023-10': {
206+
'acct': {'core_hours': 10.0, 'users': {}}
207+
}
208+
}
209+
with mock.patch.object(
210+
SlurmDB,
211+
'aggregate_usage',
212+
return_value=(usage, {'daily': {}, 'monthly': {}, 'yearly': {}}),
213+
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch.object(
214+
SlurmDB, 'cluster_resources', return_value={'cores': 0}
184215
):
185216
db = SlurmDB()
186217
with self.assertRaises(ValueError):
@@ -207,15 +238,7 @@ def fake_open(path, *args, **kwargs):
207238
), mock.patch.object(SlurmDB, 'fetch_invoices', return_value=[]), mock.patch(
208239
'builtins.open', side_effect=fake_open
209240
), mock.patch.object(
210-
SlurmDB,
211-
'_parse_slurm_conf',
212-
return_value={
213-
'nodes': 1,
214-
'sockets': 1,
215-
'cores': 100,
216-
'threads': 1,
217-
'gres': {},
218-
},
241+
SlurmDB, 'cluster_resources', return_value={'cores': 100}
219242
):
220243
db = SlurmDB()
221244
summary = db.export_summary('2024-02-01', '2024-02-29')

0 commit comments

Comments
 (0)