Skip to content

Commit ff23a73

Browse files
committed
slurm2sql: seff/sacct: Make --user compatible with --db
1 parent 514b78c commit ff23a73

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

slurm2sql.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def main(argv=sys.argv[1:], db=None, raw_sacct=None, csv_input=None):
703703
logging.lastResort.setLevel(logging.WARN)
704704
LOG.debug(args)
705705

706-
sacct_filter = process_sacct_filter(args, sacct_filter)
706+
sacct_filter = args_to_sacct_filter(args, sacct_filter)
707707

708708
# db is only given as an argument in tests (normally)
709709
if db is None:
@@ -982,7 +982,7 @@ def rows():
982982
return errors[0]
983983

984984

985-
def process_sacct_filter(args, sacct_filter):
985+
def args_to_sacct_filter(args, sacct_filter):
986986
"""Generate sacct filter args in a standard way
987987
988988
For example adding a --completed argument that translates into
@@ -1005,6 +1005,13 @@ def process_sacct_filter(args, sacct_filter):
10051005
args.running_at_time = None
10061006
return sacct_filter
10071007

1008+
def args_to_sql_where(args):
1009+
where = [ ]
1010+
if getattr(args, 'user', None):
1011+
where.append('and user=:user')
1012+
return ' '.join(where)
1013+
1014+
10081015
def import_or_open_db(args, sacct_filter, csv_input=None):
10091016
"""Helper function to either open a DB or generate a new in-mem one from sacct
10101017
@@ -1021,7 +1028,7 @@ def import_or_open_db(args, sacct_filter, csv_input=None):
10211028
LOG.warn("Warning: reading from database. Any sacct filters are ignored.")
10221029
else:
10231030
# Import fresh
1024-
sacct_filter = process_sacct_filter(args, sacct_filter)
1031+
sacct_filter = args_to_sacct_filter(args, sacct_filter)
10251032
LOG.debug(f'sacct args: {sacct_filter}')
10261033
db = sqlite3.connect(':memory:')
10271034
errors = slurm2sql(db, sacct_filter=sacct_filter,
@@ -1125,13 +1132,10 @@ def sacct_cli(argv=sys.argv[1:], csv_input=None):
11251132
db = import_or_open_db(args, sacct_filter, csv_input=csv_input)
11261133

11271134
# If we run sacct, then args.user is set to None so we don't do double filtering here
1128-
if args.user:
1129-
where_user = "WHERE user=:user"
1130-
else:
1131-
where_user = ''
1135+
where = args_to_sql_where(args)
11321136

11331137
from tabulate import tabulate
1134-
cur = db.execute(f'select {args.output} from slurm {where_user}', {'user':args.user})
1138+
cur = db.execute(f'select {args.output} from slurm WHERE true {where}', {'user':args.user})
11351139
headers = [ x[0] for x in cur.description ]
11361140
print(tabulate(cur, headers=headers, tablefmt=args.format))
11371141

@@ -1189,15 +1193,11 @@ def seff_cli(argv=sys.argv[1:], csv_input=None):
11891193
else:
11901194
order_by = ''
11911195

1192-
# If we run sacct, then args.user is set to None so we don't do double filtering here
1193-
if args.user:
1194-
where_user = "and user=:user"
1195-
else:
1196-
where_user = ''
1197-
1198-
11991196
db = import_or_open_db(args, sacct_filter, csv_input=csv_input)
12001197

1198+
# If we run sacct, then args.user is set to None so we don't do double filtering here
1199+
where = args_to_sql_where(args)
1200+
12011201
from tabulate import tabulate
12021202

12031203
if args.aggregate_user:
@@ -1218,7 +1218,7 @@ def seff_cli(argv=sys.argv[1:], csv_input=None):
12181218
round(sum(TotDiskWrite/1048576)/sum(Elapsed),2) AS write_MiBps
12191219
12201220
FROM eff
1221-
WHERE End IS NOT NULL {where_user}
1221+
WHERE End IS NOT NULL WHERE true {where}
12221222
GROUP BY user ) {order_by}
12231223
""", {'user': args.user})
12241224
headers = [ x[0] for x in cur.description ]
@@ -1249,7 +1249,7 @@ def seff_cli(argv=sys.argv[1:], csv_input=None):
12491249
round(TotDiskWrite/Elapsed/1048576,2) AS write_MiBps
12501250
12511251
FROM eff
1252-
WHERE End IS NOT NULL {where_user} ) {order_by}""", {'user': args.user})
1252+
WHERE End IS NOT NULL {where} ) {order_by}""", {'user': args.user})
12531253
headers = [ x[0] for x in cur.description ]
12541254
data = cur.fetchall()
12551255
if len(data) == 0:

0 commit comments

Comments
 (0)