Skip to content

Commit e334934

Browse files
committed
style: minor style updates/better code
1 parent 65dc05a commit e334934

File tree

1 file changed

+64
-82
lines changed

1 file changed

+64
-82
lines changed

src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py

Lines changed: 64 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@
1717
getGroupedPilotSummary()
1818
1919
"""
20-
import threading
2120
import datetime
2221
import decimal
22+
import threading
2323

24-
from DIRAC import S_OK, S_ERROR
25-
from DIRAC.Core.Base.DB import DB
2624
import DIRAC.Core.Utilities.TimeUtilities as TimeUtilities
27-
from DIRAC.Core.Utilities import DErrno
25+
from DIRAC import S_ERROR, S_OK
26+
from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername, getUsernameForDN, getVOForGroup
2827
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getCESiteMapping
29-
from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getDNForUsername, getVOForGroup
30-
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
28+
from DIRAC.Core.Base.DB import DB
29+
from DIRAC.Core.Utilities import DErrno
3130
from DIRAC.Core.Utilities.MySQL import _quotedList
31+
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
3232
from DIRAC.WorkloadManagementSystem.Client import PilotStatus
3333

3434

@@ -112,12 +112,8 @@ def setPilotStatus(
112112
setList.append(f"GridSite='{res['Value'][destination]}'")
113113

114114
set_string = ",".join(setList)
115-
req = "UPDATE PilotAgents SET " + set_string + f" WHERE PilotJobReference='{pilotRef}'"
116-
result = self._update(req, conn=conn)
117-
if not result["OK"]:
118-
return result
119-
120-
return S_OK()
115+
req = f"UPDATE PilotAgents SET {set_string} WHERE PilotJobReference='{pilotRef}'"
116+
return self._update(req, conn=conn)
121117

122118
# ###########################################################################################
123119
# FIXME: this can't work ATM because of how the DB table is made. Maybe it would be useful later.
@@ -330,9 +326,9 @@ def getPilotInfo(self, pilotRef=False, parentId=False, conn=False, paramNames=[]
330326
pilotIDs = []
331327
for row in result["Value"]:
332328
pilotDict = {}
333-
for i in range(len(parameters)):
334-
pilotDict[parameters[i]] = row[i]
335-
if parameters[i] == "PilotID":
329+
for i, par in enumerate(parameters):
330+
pilotDict[par] = row[i]
331+
if par == "PilotID":
336332
pilotIDs.append(row[i])
337333
resDict[row[0]] = pilotDict
338334

@@ -341,8 +337,7 @@ def getPilotInfo(self, pilotRef=False, parentId=False, conn=False, paramNames=[]
341337
return S_OK(resDict)
342338

343339
jobsDict = result["Value"]
344-
for pilotRef in resDict:
345-
pilotInfo = resDict[pilotRef]
340+
for pilotRef, pilotInfo in resDict.items():
346341
pilotID = pilotInfo["PilotID"]
347342
if pilotID in jobsDict:
348343
pilotInfo["Jobs"] = jobsDict[pilotID]
@@ -367,16 +362,14 @@ def setPilotBenchmark(self, pilotRef, mark):
367362
"""Set the pilot agent benchmark"""
368363

369364
req = f"UPDATE PilotAgents SET BenchMark='{mark:f}' WHERE PilotJobReference='{pilotRef}'"
370-
result = self._update(req)
371-
return result
365+
return self._update(req)
372366

373367
##########################################################################################
374368
def setAccountingFlag(self, pilotRef, mark="True"):
375369
"""Set the pilot AccountingSent flag"""
376370

377371
req = f"UPDATE PilotAgents SET AccountingSent='{mark}' WHERE PilotJobReference='{pilotRef}'"
378-
result = self._update(req)
379-
return result
372+
return self._update(req)
380373

381374
##########################################################################################
382375
def storePilotOutput(self, pilotRef, output, error):
@@ -409,21 +402,19 @@ def getPilotOutput(self, pilotRef):
409402
result = self._query(req)
410403
if not result["OK"]:
411404
return result
412-
else:
413-
if result["Value"]:
414-
try:
415-
stdout = result["Value"][0][0].decode() # account for the use of BLOBs
416-
error = result["Value"][0][1].decode()
417-
except AttributeError:
418-
stdout = result["Value"][0][0]
419-
error = result["Value"][0][1]
420-
if stdout == '""':
421-
stdout = ""
422-
if error == '""':
423-
error = ""
424-
return S_OK({"StdOut": stdout, "StdErr": error})
425-
else:
426-
return S_ERROR("PilotJobReference " + pilotRef + " not found")
405+
if not result["Value"]:
406+
return S_ERROR(f"PilotJobReference {pilotRef} not found")
407+
try:
408+
stdout = result["Value"][0][0].decode() # account for the use of BLOBs
409+
error = result["Value"][0][1].decode()
410+
except AttributeError:
411+
stdout = result["Value"][0][0]
412+
error = result["Value"][0][1]
413+
if stdout == '""':
414+
stdout = ""
415+
if error == '""':
416+
error = ""
417+
return S_OK({"StdOut": stdout, "StdErr": error})
427418

428419
##########################################################################################
429420
def __getPilotID(self, pilotRef):
@@ -434,19 +425,17 @@ def __getPilotID(self, pilotRef):
434425
result = self._query(req)
435426
if not result["OK"]:
436427
return 0
437-
else:
438-
if result["Value"]:
439-
return int(result["Value"][0][0])
440-
return 0
441-
else:
442-
refString = ",".join(["'" + ref + "'" for ref in pilotRef])
443-
req = f"SELECT PilotID from PilotAgents WHERE PilotJobReference in ( {refString} )"
444-
result = self._query(req)
445-
if not result["OK"]:
446-
return []
447428
if result["Value"]:
448-
return [x[0] for x in result["Value"]]
429+
return int(result["Value"][0][0])
430+
return 0
431+
refString = ",".join(["'" + ref + "'" for ref in pilotRef])
432+
req = f"SELECT PilotID from PilotAgents WHERE PilotJobReference in ( {refString} )"
433+
result = self._query(req)
434+
if not result["OK"]:
449435
return []
436+
if result["Value"]:
437+
return [x[0] for x in result["Value"]]
438+
return []
450439

451440
##########################################################################################
452441
def setJobForPilot(self, jobID, pilotRef, site=None, updateStatus=True):
@@ -455,17 +444,13 @@ def setJobForPilot(self, jobID, pilotRef, site=None, updateStatus=True):
455444
pilotID = self.__getPilotID(pilotRef)
456445
if pilotID:
457446
if updateStatus:
458-
reason = "Report from job %d" % int(jobID)
447+
reason = f"Report from job {jobID}"
459448
result = self.setPilotStatus(pilotRef, status=PilotStatus.RUNNING, statusReason=reason, gridSite=site)
460449
if not result["OK"]:
461450
return result
462-
req = "INSERT INTO JobToPilotMapping (PilotID,JobID,StartTime) VALUES (%d,%d,UTC_TIMESTAMP())" % (
463-
pilotID,
464-
jobID,
465-
)
451+
req = f"INSERT INTO JobToPilotMapping (PilotID,JobID,StartTime) VALUES ({int(pilotID)}, {int(jobID)}, UTC_TIMESTAMP())"
466452
return self._update(req)
467-
else:
468-
return S_ERROR("PilotJobReference " + pilotRef + " not found")
453+
return S_ERROR(f"PilotJobReference {pilotRef} not found")
469454

470455
##########################################################################################
471456
def setCurrentJobID(self, pilotRef, jobID):
@@ -568,16 +553,15 @@ def getPilotSummary(self, startdate="", enddate=""):
568553
result = self._query(req)
569554
if not result["OK"]:
570555
return result
571-
else:
572-
if result["Value"]:
573-
for res in result["Value"]:
574-
site = res[0]
575-
count = res[1]
576-
if site:
577-
if site not in summary_dict:
578-
summary_dict[site] = {}
579-
summary_dict[site][st] = int(count)
580-
summary_dict["Total"][st] += int(count)
556+
if result["Value"]:
557+
for res in result["Value"]:
558+
site = res[0]
559+
count = res[1]
560+
if site:
561+
if site not in summary_dict:
562+
summary_dict[site] = {}
563+
summary_dict[site][st] = int(count)
564+
summary_dict["Total"][st] += int(count)
581565

582566
# Get aborted pilots in the last hour, day
583567
req = "SELECT DestinationSite,count(DestinationSite) FROM PilotAgents WHERE Status='Aborted' AND "
@@ -723,14 +707,12 @@ def _getElementStatus(self, total, eff):
723707
if total > 10:
724708
if eff < 25.0:
725709
return "Bad"
726-
elif eff < 60.0:
710+
if eff < 60.0:
727711
return "Poor"
728-
elif eff < 85.0:
712+
if eff < 85.0:
729713
return "Fair"
730-
else:
731-
return "Good"
732-
else:
733-
return "Idle"
714+
return "Good"
715+
return "Idle"
734716

735717
def getPilotSummaryWeb(self, selectDict, sortList, startItem, maxItems):
736718
"""Get summary of the pilot jobs status by CE/site in a standard structure"""
@@ -861,26 +843,26 @@ def getPilotSummaryWeb(self, selectDict, sortList, startItem, maxItems):
861843

862844
records = []
863845
siteSumDict = {}
864-
for site in resultDict:
846+
for site, ces in resultDict.items():
865847
sumDict = {}
866848
for state in allStateNames:
867849
if state not in sumDict:
868850
sumDict[state] = 0
869851
sumDict["Total"] = 0
870-
for ce in resultDict[site]:
852+
for ce, ceDict in ces.items():
871853
itemList = [site, ce]
872854
total = 0
873855
for state in allStateNames:
874-
itemList.append(resultDict[site][ce][state])
875-
sumDict[state] += resultDict[site][ce][state]
856+
itemList.append(ceDict[state])
857+
sumDict[state] += ceDict[state]
876858
if state == PilotStatus.DONE:
877-
done = resultDict[site][ce][state]
859+
done = ceDict[state]
878860
if state == "Done_Empty":
879-
empty = resultDict[site][ce][state]
861+
empty = ceDict[state]
880862
if state == PilotStatus.ABORTED:
881-
aborted = resultDict[site][ce][state]
882-
if state != "Aborted_Hour" and state != "Done_Empty":
883-
total += resultDict[site][ce][state]
863+
aborted = ceDict[state]
864+
if state not in ("Aborted_Hour", "Done_Empty"):
865+
total += ceDict[state]
884866

885867
sumDict["Total"] += total
886868
# Add the total number of pilots seen in the last day
@@ -915,10 +897,10 @@ def getPilotSummaryWeb(self, selectDict, sortList, startItem, maxItems):
915897
else:
916898
itemList.append("Idle")
917899

918-
if len(resultDict[site]) == 1 or expand_site:
900+
if len(ces) == 1 or expand_site:
919901
records.append(itemList)
920902

921-
if len(resultDict[site]) > 1 and not expand_site:
903+
if len(ces) > 1 and not expand_site:
922904
itemList = [site, "Multiple"]
923905
for state in allStateNames + ["Total"]:
924906
if state in sumDict:
@@ -1210,7 +1192,7 @@ def __init__(self, columnList):
12101192

12111193
self._columns += self.pstates # MySQL._query() does not give us column names, sadly.
12121194

1213-
def buildSQL(self, selectDict=None):
1195+
def buildSQL(self):
12141196
"""
12151197
Build an SQL query to create a table with all status counts in one row, ("pivoted")
12161198
grouped by columns in the column list.

0 commit comments

Comments
 (0)