|
1 | 1 | import FWCore.ParameterSet.Config as cms |
2 | 2 | import os, sys, json |
3 | | -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter |
| 3 | +from HeterogeneousCore.SonicTriton.customize import getDefaultClientPSet, getParser, getOptions, applyOptions, applyClientOptions |
4 | 4 |
|
5 | 5 | # module/model correspondence |
6 | 6 | models = { |
|
13 | 13 |
|
14 | 14 | # allowed client modes (choices for --mode)
15 | 15 | allowed_modes = ["Async","PseudoAsync","Sync"] |
16 | | -allowed_compression = ["none","deflate","gzip"] |
17 | | -allowed_devices = ["auto","cpu","gpu"] |
18 | | -allowed_containers = ["apptainer","docker","podman","podman-hpc"] |
19 | 16 |
|
20 | | -parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) |
21 | | -parser.add_argument("--maxEvents", default=-1, type=int, help="Number of events to process (-1 for all)") |
22 | | -parser.add_argument("--serverName", default="default", type=str, help="name for server (used internally)") |
23 | | -parser.add_argument("--address", default="", type=str, help="server address") |
24 | | -parser.add_argument("--port", default=8001, type=int, help="server port") |
25 | | -parser.add_argument("--timeout", default=30, type=int, help="timeout for requests") |
26 | | -parser.add_argument("--timeoutUnit", default="seconds", type=str, help="unit for timeout") |
27 | | -parser.add_argument("--params", default="", type=str, help="json file containing server address/port") |
28 | | -parser.add_argument("--threads", default=1, type=int, help="number of threads") |
29 | | -parser.add_argument("--streams", default=0, type=int, help="number of streams") |
| 17 | +parser = getParser() |
30 | 18 | parser.add_argument("--modules", metavar=("MODULES"), default=["TritonGraphProducer"], nargs='+', type=str, choices=list(models), help="list of modules to run (choices: %(choices)s)") |
31 | 19 | parser.add_argument("--models", default=["gat_test"], nargs='+', type=str, help="list of models (same length as modules, or just 1 entry if all modules use same model)") |
32 | 20 | parser.add_argument("--mode", default="Async", type=str, choices=allowed_modes, help="mode for client") |
33 | | -parser.add_argument("--verbose", default=False, action="store_true", help="enable all verbose output") |
34 | | -parser.add_argument("--verboseClient", default=False, action="store_true", help="enable verbose output for clients") |
35 | | -parser.add_argument("--verboseServer", default=False, action="store_true", help="enable verbose output for server") |
36 | | -parser.add_argument("--verboseService", default=False, action="store_true", help="enable verbose output for TritonService") |
37 | | -parser.add_argument("--verboseDiscovery", default=False, action="store_true", help="enable verbose output just for server discovery in TritonService") |
38 | 21 | parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") |
39 | | -parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server") |
40 | 22 | parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") |
41 | 23 | parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") |
42 | | -parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory") |
43 | | -parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression") |
44 | | -parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication") |
45 | | -parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server") |
46 | | -parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server") |
47 | | -parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request") |
48 | | -options = parser.parse_args() |
49 | 24 |
|
50 | | -if len(options.params)>0: |
51 | | - with open(options.params,'r') as pfile: |
52 | | - pdict = json.load(pfile) |
53 | | - options.address = pdict["address"] |
54 | | - options.port = int(pdict["port"]) |
55 | | - print("server = "+options.address+":"+str(options.port)) |
| 25 | +options = getOptions(parser, verbose=True) |
56 | 26 |
|
57 | 27 | # check models and modules |
58 | 28 | if len(options.modules)!=len(options.models): |
|
68 | 38 | process = cms.Process('tritonTest',enableSonicTriton) |
69 | 39 |
|
70 | 40 | process.load("HeterogeneousCore.SonicTriton.TritonService_cff") |
71 | | - |
72 | | -process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) ) |
73 | | - |
74 | 41 | process.source = cms.Source("EmptySource") |
75 | 42 |
|
76 | | -process.TritonService.verbose = options.verbose or options.verboseService or options.verboseDiscovery |
77 | | -process.TritonService.fallback.verbose = options.verbose or options.verboseServer |
78 | | -process.TritonService.fallback.container = options.container |
79 | | -process.TritonService.fallback.device = options.device |
80 | | -if len(options.fallbackName)>0: |
81 | | - process.TritonService.fallback.instanceBaseName = options.fallbackName |
82 | | -if len(options.address)>0: |
83 | | - process.TritonService.servers.append( |
84 | | - cms.PSet( |
85 | | - name = cms.untracked.string(options.serverName), |
86 | | - address = cms.untracked.string(options.address), |
87 | | - port = cms.untracked.uint32(options.port), |
88 | | - useSsl = cms.untracked.bool(options.ssl), |
89 | | - rootCertificates = cms.untracked.string(""), |
90 | | - privateKey = cms.untracked.string(""), |
91 | | - certificateChain = cms.untracked.string(""), |
92 | | - ) |
93 | | - ) |
94 | | - |
95 | 43 | # Let it run |
96 | 44 | process.p = cms.Path() |
97 | 45 |
|
|
101 | 49 | "Analyzer": cms.EDAnalyzer, |
102 | 50 | } |
103 | 51 |
|
104 | | -keepMsgs = [] |
105 | | -if options.verbose or options.verboseDiscovery: |
106 | | - keepMsgs.append('TritonDiscovery') |
107 | | -if options.verbose or options.verboseClient: |
108 | | - keepMsgs.append('TritonClient') |
109 | | -if options.verbose or options.verboseService: |
110 | | - keepMsgs.append('TritonService') |
| 52 | +defaultClient = applyClientOptions(getDefaultClientPSet().clone(), options) |
111 | 53 |
|
112 | 54 | for im,module in enumerate(options.modules): |
113 | 55 | model = options.models[im] |
114 | 56 | Module = [obj for name,obj in modules.items() if name in module][0] |
115 | 57 | setattr(process, module, |
116 | 58 | Module(module, |
117 | | - Client = cms.PSet( |
| 59 | + Client = defaultClient.clone( |
118 | 60 | mode = cms.string(options.mode), |
119 | 61 | preferredServer = cms.untracked.string(""), |
120 | | - timeout = cms.untracked.uint32(options.timeout), |
121 | | - timeoutUnit = cms.untracked.string(options.timeoutUnit), |
122 | 62 | modelName = cms.string(model), |
123 | 63 | modelVersion = cms.string(""), |
124 | 64 | modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)), |
125 | | - verbose = cms.untracked.bool(options.verbose or options.verboseClient), |
126 | | - allowedTries = cms.untracked.uint32(options.tries), |
127 | | - useSharedMemory = cms.untracked.bool(not options.noShm), |
128 | | - compression = cms.untracked.string(options.compression), |
129 | 65 | ) |
130 | 66 | ) |
131 | 67 | ) |
|
148 | 84 | processModule.edgeMax = cms.uint32(15000) |
149 | 85 | processModule.brief = cms.bool(options.brief) |
150 | 86 | process.p += processModule |
151 | | - if options.verbose or options.verboseClient: |
152 | | - keepMsgs.extend([module,module+':TritonClient']) |
153 | 87 | if options.testother: |
154 | 88 | # clone modules to test both gRPC and shared memory |
155 | 89 | _module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM" |
|
160 | 94 | ) |
161 | 95 | processModule2 = getattr(process, _module2) |
162 | 96 | process.p += processModule2 |
163 | | - if options.verbose or options.verboseClient: |
164 | | - keepMsgs.extend([_module2,_module2+':TritonClient']) |
165 | | - |
166 | | -process.load('FWCore/MessageService/MessageLogger_cfi') |
167 | | -process.MessageLogger.cerr.FwkReport.reportEvery = 500 |
168 | | -for msg in keepMsgs: |
169 | | - setattr(process.MessageLogger.cerr,msg, |
170 | | - cms.untracked.PSet( |
171 | | - limit = cms.untracked.int32(10000000), |
172 | | - ) |
173 | | - ) |
174 | | - |
175 | | -if options.threads>0: |
176 | | - process.options.numberOfThreads = options.threads |
177 | | - process.options.numberOfStreams = options.streams |
178 | 97 |
|
| 98 | +process = applyOptions(process, options) |
0 commit comments