|
| 1 | +import FWCore.ParameterSet.Config as cms |
| 2 | + |
| 3 | +def getDefaultClientPSet(): |
| 4 | + from HeterogeneousCore.SonicTriton.TritonGraphAnalyzer import TritonGraphAnalyzer |
| 5 | + temp = TritonGraphAnalyzer() |
| 6 | + return temp.Client |
| 7 | + |
| 8 | +def getParser(): |
| 9 | + allowed_compression = ["none","deflate","gzip"] |
| 10 | + allowed_devices = ["auto","cpu","gpu"] |
| 11 | + allowed_containers = ["apptainer","docker","podman","podman-hpc"] |
| 12 | + |
| 13 | + from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter |
| 14 | + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) |
| 15 | + parser.add_argument("--maxEvents", default=-1, type=int, help="Number of events to process (-1 for all)") |
| 16 | + parser.add_argument("--serverName", default="default", type=str, help="name for server (used internally)") |
| 17 | + parser.add_argument("--address", default="", type=str, help="server address") |
| 18 | + parser.add_argument("--port", default=8001, type=int, help="server port") |
| 19 | + parser.add_argument("--timeout", default=30, type=int, help="timeout for requests") |
| 20 | + parser.add_argument("--timeoutUnit", default="seconds", type=str, help="unit for timeout") |
| 21 | + parser.add_argument("--params", default="", type=str, help="json file containing server address/port") |
| 22 | + parser.add_argument("--threads", default=1, type=int, help="number of threads") |
| 23 | + parser.add_argument("--streams", default=0, type=int, help="number of streams") |
| 24 | + parser.add_argument("--verbose", default=False, action="store_true", help="enable all verbose output") |
| 25 | + parser.add_argument("--verboseClient", default=False, action="store_true", help="enable verbose output for clients") |
| 26 | + parser.add_argument("--verboseServer", default=False, action="store_true", help="enable verbose output for server") |
| 27 | + parser.add_argument("--verboseService", default=False, action="store_true", help="enable verbose output for TritonService") |
| 28 | + parser.add_argument("--verboseDiscovery", default=False, action="store_true", help="enable verbose output just for server discovery in TritonService") |
| 29 | + parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory") |
| 30 | + parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression") |
| 31 | + parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication") |
| 32 | + parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request") |
| 33 | + parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server") |
| 34 | + parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server") |
| 35 | + parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server") |
| 36 | + parser.add_argument("--imageName", default="", type=str, help="container image name for fallback server") |
| 37 | + parser.add_argument("--tempDir", default="", type=str, help="temp directory for fallback server") |
| 38 | + |
| 39 | + return parser |
| 40 | + |
| 41 | +def getOptions(parser, verbose=False): |
| 42 | + options = parser.parse_args() |
| 43 | + |
| 44 | + if len(options.params)>0: |
| 45 | + with open(options.params,'r') as pfile: |
| 46 | + pdict = json.load(pfile) |
| 47 | + options.address = pdict["address"] |
| 48 | + options.port = int(pdict["port"]) |
| 49 | + if verbose: print("server = "+options.address+":"+str(options.port)) |
| 50 | + |
| 51 | + return options |
| 52 | + |
| 53 | +def applyOptions(process, options, applyToModules=False): |
| 54 | + process.maxEvents.input = cms.untracked.int32(options.maxEvents) |
| 55 | + |
| 56 | + if options.threads>0: |
| 57 | + process.options.numberOfThreads = options.threads |
| 58 | + process.options.numberOfStreams = options.streams |
| 59 | + |
| 60 | + if options.verbose: |
| 61 | + configureLoggingAll(process) |
| 62 | + else: |
| 63 | + configureLogging(process, |
| 64 | + client=options.verboseClient, |
| 65 | + server=options.verboseServer, |
| 66 | + service=options.verboseService, |
| 67 | + discovery=options.verboseDiscovery |
| 68 | + ) |
| 69 | + |
| 70 | + if hasattr(process,'TritonService'): |
| 71 | + process.TritonService.fallback.container = options.container |
| 72 | + process.TritonService.fallback.imageName = options.imageName |
| 73 | + process.TritonService.fallback.tempDir = options.tempDir |
| 74 | + process.TritonService.fallback.device = options.device |
| 75 | + if len(options.fallbackName)>0: |
| 76 | + process.TritonService.fallback.instanceBaseName = options.fallbackName |
| 77 | + if len(options.address)>0: |
| 78 | + process.TritonService.servers.append( |
| 79 | + cms.PSet( |
| 80 | + name = cms.untracked.string(options.serverName), |
| 81 | + address = cms.untracked.string(options.address), |
| 82 | + port = cms.untracked.uint32(options.port), |
| 83 | + useSsl = cms.untracked.bool(options.ssl), |
| 84 | + rootCertificates = cms.untracked.string(""), |
| 85 | + privateKey = cms.untracked.string(""), |
| 86 | + certificateChain = cms.untracked.string(""), |
| 87 | + ) |
| 88 | + ) |
| 89 | + |
| 90 | + if applyToModules: |
| 91 | + process = configureModules(process, **getClientOptions(options)) |
| 92 | + |
| 93 | + return process |
| 94 | + |
| 95 | +def getClientOptions(options): |
| 96 | + return dict( |
| 97 | + compression = options.compression, |
| 98 | + useSharedMemory = not options.noShm, |
| 99 | + timeout = options.timeout, |
| 100 | + timeoutUnit = options.timeoutUnit, |
| 101 | + allowedTries = options.tries, |
| 102 | + ) |
| 103 | + |
| 104 | +def applyClientOptions(client, options): |
| 105 | + return configureClient(client, **getClientOptions(options)) |
| 106 | + |
| 107 | +def configureModules(process, modules=None, **kwargs): |
| 108 | + if modules is None: |
| 109 | + modules = {} |
| 110 | + modules.update(process._Process__producers) |
| 111 | + modules.update(process._Process__filters) |
| 112 | + modules.update(process._Process__analyzers) |
| 113 | + configured = [] |
| 114 | + for pname,producer in modules.items(): |
| 115 | + if hasattr(producer,'Client'): |
| 116 | + producer.Client = configureClient(producer.Client, **kwargs) |
| 117 | + configured.append(pname) |
| 118 | + return process, configured |
| 119 | + |
| 120 | +def configureClient(client, **kwargs): |
| 121 | + client.update_(kwargs) |
| 122 | + return client |
| 123 | + |
| 124 | +def configureLogging(process, client=False, server=False, service=False, discovery=False): |
| 125 | + if not any([client, server, service, discovery]): |
| 126 | + return |
| 127 | + |
| 128 | + keepMsgs = [] |
| 129 | + if discovery: |
| 130 | + keepMsgs.append('TritonDiscovery') |
| 131 | + if client: |
| 132 | + keepMsgs.append('TritonClient') |
| 133 | + if service: |
| 134 | + keepMsgs.append('TritonService') |
| 135 | + |
| 136 | + if hasattr(process,'TritonService'): |
| 137 | + process.TritonService.verbose = service or discovery |
| 138 | + process.TritonService.fallback.verbose = server |
| 139 | + if client: |
| 140 | + process, configured = configureModules(process, verbose = True) |
| 141 | + for module in configured: |
| 142 | + keepMsgs.extend([module, module+':TritonClient']) |
| 143 | + |
| 144 | + if not hasattr(process,'MessageLogger'): |
| 145 | + process.load('FWCore/MessageService/MessageLogger_cfi') |
| 146 | + process.MessageLogger.cerr.FwkReport.reportEvery = 500 |
| 147 | + for msg in keepMsgs: |
| 148 | + setattr(process.MessageLogger.cerr, msg, |
| 149 | + cms.untracked.PSet( |
| 150 | + limit = cms.untracked.int32(10000000), |
| 151 | + ) |
| 152 | + ) |
| 153 | + |
| 154 | + return process |
| 155 | + |
| 156 | +# dedicated functions for cmsDriver customization |
| 157 | + |
| 158 | +def configureLoggingClient(process): |
| 159 | + return configureLogging(process, client=True) |
| 160 | + |
| 161 | +def configureLoggingServer(process): |
| 162 | + return configureLogging(process, server=True) |
| 163 | + |
| 164 | +def configureLoggingService(process): |
| 165 | + return configureLogging(process, service=True) |
| 166 | + |
| 167 | +def configureLoggingDiscovery(process): |
| 168 | + return configureLogging(process, discovery=True) |
| 169 | + |
| 170 | +def configureLoggingAll(process): |
| 171 | + return configureLogging(process, client=True, server=True, service=True, discovery=True) |
0 commit comments