Skip to content

Commit d44c9cc

Browse files
committed
introduce useful customization functions for SonicTriton tests
1 parent 325625d commit d44c9cc

File tree

2 files changed

+177
-86
lines changed

2 files changed

+177
-86
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import FWCore.ParameterSet.Config as cms
2+
3+
def getDefaultClientPSet():
    """Return the ``Client`` parameter set of a newly built TritonGraphAnalyzer."""
    # imported inside the function, so the analyzer module is only loaded on demand
    from HeterogeneousCore.SonicTriton.TritonGraphAnalyzer import TritonGraphAnalyzer
    return TritonGraphAnalyzer().Client
7+
8+
def getParser():
    """Build the argument parser with the common SonicTriton test options.

    Returns an ``argparse.ArgumentParser`` (using ``ArgumentDefaultsHelpFormatter``)
    to which callers may add their own arguments before parsing.
    """
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

    compression_choices = ["none","deflate","gzip"]
    device_choices = ["auto","cpu","gpu"]
    container_choices = ["apptainer","docker","podman","podman-hpc"]

    def valued(default, kind, text, **extra):
        # keyword set for an option that takes a value
        return dict(default=default, type=kind, help=text, **extra)

    def switch(text):
        # keyword set for a boolean on/off flag
        return dict(default=False, action="store_true", help=text)

    # registration order is kept, since it determines the order of the --help output
    spec = [
        ("--maxEvents", valued(-1, int, "Number of events to process (-1 for all)")),
        ("--serverName", valued("default", str, "name for server (used internally)")),
        ("--address", valued("", str, "server address")),
        ("--port", valued(8001, int, "server port")),
        ("--timeout", valued(30, int, "timeout for requests")),
        ("--timeoutUnit", valued("seconds", str, "unit for timeout")),
        ("--params", valued("", str, "json file containing server address/port")),
        ("--threads", valued(1, int, "number of threads")),
        ("--streams", valued(0, int, "number of streams")),
        ("--verbose", switch("enable all verbose output")),
        ("--verboseClient", switch("enable verbose output for clients")),
        ("--verboseServer", switch("enable verbose output for server")),
        ("--verboseService", switch("enable verbose output for TritonService")),
        ("--verboseDiscovery", switch("enable verbose output just for server discovery in TritonService")),
        ("--noShm", switch("disable shared memory")),
        ("--compression", valued("", str, "enable I/O compression", choices=compression_choices)),
        ("--ssl", switch("enable SSL authentication for server communication")),
        ("--tries", valued(0, int, "number of retries for failed request")),
        # str.lower makes the device/container choices case-insensitive
        ("--device", valued("auto", str.lower, "specify device for fallback server", choices=device_choices)),
        ("--container", valued("apptainer", str.lower, "specify container for fallback server", choices=container_choices)),
        ("--fallbackName", valued("", str, "name for fallback server")),
        ("--imageName", valued("", str, "container image name for fallback server")),
        ("--tempDir", valued("", str, "temp directory for fallback server")),
    ]

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    for flag, settings in spec:
        parser.add_argument(flag, **settings)

    return parser
40+
41+
def getOptions(parser, verbose=False):
    """Parse the command line and apply server settings from a JSON file, if given.

    Parameters:
        parser: argument parser providing at least ``params``, ``address`` and ``port``.
        verbose: if true, print the server address/port taken from the JSON file.

    Returns:
        The parsed options namespace. When ``--params`` names a file, ``address``
        and ``port`` are overwritten with the file's "address" and "port" entries.
    """
    # this file has no module-level json import, so bring it into scope here;
    # without it the --params branch fails with NameError
    import json

    options = parser.parse_args()

    if len(options.params)>0:
        with open(options.params,'r') as pfile:
            pdict = json.load(pfile)
        options.address = pdict["address"]
        options.port = int(pdict["port"])
        if verbose: print("server = "+options.address+":"+str(options.port))

    return options
52+
53+
def applyOptions(process, options, applyToModules=False):
    """Apply parsed command-line options to a process.

    Sets the event limit, thread/stream counts and logging verbosity; if the
    process has a ``TritonService``, also sets its fallback server settings and
    appends a server entry when an address was given.

    Parameters:
        process: the process to modify in place.
        options: parsed options, as produced by the parser from ``getParser``.
        applyToModules: if true, also push the client options into every module
            that has a ``Client`` parameter set.

    Returns:
        The modified process.
    """
    process.maxEvents.input = cms.untracked.int32(options.maxEvents)

    if options.threads>0:
        process.options.numberOfThreads = options.threads
        process.options.numberOfStreams = options.streams

    if options.verbose:
        configureLoggingAll(process)
    else:
        configureLogging(process,
            client=options.verboseClient,
            server=options.verboseServer,
            service=options.verboseService,
            discovery=options.verboseDiscovery
        )

    if hasattr(process,'TritonService'):
        process.TritonService.fallback.container = options.container
        process.TritonService.fallback.imageName = options.imageName
        process.TritonService.fallback.tempDir = options.tempDir
        process.TritonService.fallback.device = options.device
        if len(options.fallbackName)>0:
            process.TritonService.fallback.instanceBaseName = options.fallbackName
        if len(options.address)>0:
            process.TritonService.servers.append(
                cms.PSet(
                    name = cms.untracked.string(options.serverName),
                    address = cms.untracked.string(options.address),
                    port = cms.untracked.uint32(options.port),
                    useSsl = cms.untracked.bool(options.ssl),
                    rootCertificates = cms.untracked.string(""),
                    privateKey = cms.untracked.string(""),
                    certificateChain = cms.untracked.string(""),
                )
            )

    if applyToModules:
        # configureModules returns (process, configured names); keep only the
        # process so that this function does not return a tuple
        process, _ = configureModules(process, **getClientOptions(options))

    return process
94+
95+
def getClientOptions(options):
    """Translate parsed command-line options into client parameter names/values."""
    return {
        "compression": options.compression,
        # the command-line flag disables shared memory, the client parameter enables it
        "useSharedMemory": not options.noShm,
        "timeout": options.timeout,
        "timeoutUnit": options.timeoutUnit,
        "allowedTries": options.tries,
    }
103+
104+
def applyClientOptions(client, options):
    """Update ``client`` with the client settings derived from ``options`` and return it."""
    settings = getClientOptions(options)
    return configureClient(client, **settings)
106+
107+
def configureModules(process, modules=None, **kwargs):
    """Apply ``kwargs`` to the ``Client`` of every module that has one.

    Parameters:
        process: the process whose modules are configured.
        modules: mapping of label to module; if None, the process's producers,
            filters and analyzers are collected.
        **kwargs: client parameters forwarded to ``configureClient``.

    Returns:
        A tuple ``(process, configured)``, where ``configured`` lists the labels
        of the modules whose ``Client`` was updated.
    """
    if modules is None:
        modules = {}
        registries = (
            process._Process__producers,
            process._Process__filters,
            process._Process__analyzers,
        )
        for registry in registries:
            modules.update(registry)

    configured = []
    for label, module in modules.items():
        # modules without a Client parameter set are left untouched
        if not hasattr(module,'Client'):
            continue
        module.Client = configureClient(module.Client, **kwargs)
        configured.append(label)
    return process, configured
119+
120+
def configureClient(client, **kwargs):
    """Update ``client`` in place with the given keyword parameters and return it."""
    # update_ receives the collected keyword arguments as a single dict
    client.update_(kwargs)
    return client
123+
124+
def configureLogging(process, client=False, server=False, service=False, discovery=False):
    """Enable verbose output for the selected SonicTriton components.

    Parameters:
        process: the process to modify in place.
        client: enable verbose output in every module that has a ``Client``.
        server: enable verbose output for the fallback server.
        service: enable verbose output for TritonService.
        discovery: enable verbose output for server discovery.

    Returns:
        The process. It is returned unchanged when no component is selected, so
        the result can always be assigned back (``process = configureLogging(...)``).
    """
    if not (client or server or service or discovery):
        # nothing requested: still return the process rather than None
        return process

    keepMsgs = []
    if discovery:
        keepMsgs.append('TritonDiscovery')
    if client:
        keepMsgs.append('TritonClient')
    if service:
        keepMsgs.append('TritonService')

    if hasattr(process,'TritonService'):
        process.TritonService.verbose = service or discovery
        process.TritonService.fallback.verbose = server
    if client:
        process, configured = configureModules(process, verbose = True)
        # keep the messages for each configured module and its client category
        for module in configured:
            keepMsgs.extend([module, module+':TritonClient'])

    if not hasattr(process,'MessageLogger'):
        process.load('FWCore/MessageService/MessageLogger_cfi')
    process.MessageLogger.cerr.FwkReport.reportEvery = 500
    for msg in keepMsgs:
        setattr(process.MessageLogger.cerr, msg,
            cms.untracked.PSet(
                limit = cms.untracked.int32(10000000),
            )
        )

    return process
155+
156+
# dedicated functions for cmsDriver customization
157+
158+
def configureLoggingClient(process):
    """Single-argument customization: verbose output for clients only."""
    return configureLogging(process, client=True, server=False, service=False, discovery=False)
160+
161+
def configureLoggingServer(process):
    """Single-argument customization: verbose output for the server only."""
    return configureLogging(process, client=False, server=True, service=False, discovery=False)
163+
164+
def configureLoggingService(process):
    """Single-argument customization: verbose output for TritonService only."""
    return configureLogging(process, client=False, server=False, service=True, discovery=False)
166+
167+
def configureLoggingDiscovery(process):
    """Single-argument customization: verbose output for server discovery only."""
    return configureLogging(process, client=False, server=False, service=False, discovery=True)
169+
170+
def configureLoggingAll(process):
    """Single-argument customization: verbose output for all components."""
    everything = dict(client=True, server=True, service=True, discovery=True)
    return configureLogging(process, **everything)
Lines changed: 6 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import FWCore.ParameterSet.Config as cms
22
import os, sys, json
3-
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
3+
from HeterogeneousCore.SonicTriton.customize import getDefaultClientPSet, getParser, getOptions, applyOptions, applyClientOptions
44

55
# module/model correspondence
66
models = {
@@ -13,46 +13,16 @@
1313

1414
# other choices
1515
allowed_modes = ["Async","PseudoAsync","Sync"]
16-
allowed_compression = ["none","deflate","gzip"]
17-
allowed_devices = ["auto","cpu","gpu"]
18-
allowed_containers = ["apptainer","docker","podman","podman-hpc"]
1916

20-
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
21-
parser.add_argument("--maxEvents", default=-1, type=int, help="Number of events to process (-1 for all)")
22-
parser.add_argument("--serverName", default="default", type=str, help="name for server (used internally)")
23-
parser.add_argument("--address", default="", type=str, help="server address")
24-
parser.add_argument("--port", default=8001, type=int, help="server port")
25-
parser.add_argument("--timeout", default=30, type=int, help="timeout for requests")
26-
parser.add_argument("--timeoutUnit", default="seconds", type=str, help="unit for timeout")
27-
parser.add_argument("--params", default="", type=str, help="json file containing server address/port")
28-
parser.add_argument("--threads", default=1, type=int, help="number of threads")
29-
parser.add_argument("--streams", default=0, type=int, help="number of streams")
17+
parser = getParser()
3018
parser.add_argument("--modules", metavar=("MODULES"), default=["TritonGraphProducer"], nargs='+', type=str, choices=list(models), help="list of modules to run (choices: %(choices)s)")
3119
parser.add_argument("--models", default=["gat_test"], nargs='+', type=str, help="list of models (same length as modules, or just 1 entry if all modules use same model)")
3220
parser.add_argument("--mode", default="Async", type=str, choices=allowed_modes, help="mode for client")
33-
parser.add_argument("--verbose", default=False, action="store_true", help="enable all verbose output")
34-
parser.add_argument("--verboseClient", default=False, action="store_true", help="enable verbose output for clients")
35-
parser.add_argument("--verboseServer", default=False, action="store_true", help="enable verbose output for server")
36-
parser.add_argument("--verboseService", default=False, action="store_true", help="enable verbose output for TritonService")
37-
parser.add_argument("--verboseDiscovery", default=False, action="store_true", help="enable verbose output just for server discovery in TritonService")
3821
parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules")
39-
parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server")
4022
parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes")
4123
parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa")
42-
parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory")
43-
parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression")
44-
parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication")
45-
parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server")
46-
parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server")
47-
parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request")
48-
options = parser.parse_args()
4924

50-
if len(options.params)>0:
51-
with open(options.params,'r') as pfile:
52-
pdict = json.load(pfile)
53-
options.address = pdict["address"]
54-
options.port = int(pdict["port"])
55-
print("server = "+options.address+":"+str(options.port))
25+
options = getOptions(parser, verbose=True)
5626

5727
# check models and modules
5828
if len(options.modules)!=len(options.models):
@@ -68,30 +38,8 @@
6838
process = cms.Process('tritonTest',enableSonicTriton)
6939

7040
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")
71-
72-
process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) )
73-
7441
process.source = cms.Source("EmptySource")
7542

76-
process.TritonService.verbose = options.verbose or options.verboseService or options.verboseDiscovery
77-
process.TritonService.fallback.verbose = options.verbose or options.verboseServer
78-
process.TritonService.fallback.container = options.container
79-
process.TritonService.fallback.device = options.device
80-
if len(options.fallbackName)>0:
81-
process.TritonService.fallback.instanceBaseName = options.fallbackName
82-
if len(options.address)>0:
83-
process.TritonService.servers.append(
84-
cms.PSet(
85-
name = cms.untracked.string(options.serverName),
86-
address = cms.untracked.string(options.address),
87-
port = cms.untracked.uint32(options.port),
88-
useSsl = cms.untracked.bool(options.ssl),
89-
rootCertificates = cms.untracked.string(""),
90-
privateKey = cms.untracked.string(""),
91-
certificateChain = cms.untracked.string(""),
92-
)
93-
)
94-
9543
# Let it run
9644
process.p = cms.Path()
9745

@@ -101,31 +49,19 @@
10149
"Analyzer": cms.EDAnalyzer,
10250
}
10351

104-
keepMsgs = []
105-
if options.verbose or options.verboseDiscovery:
106-
keepMsgs.append('TritonDiscovery')
107-
if options.verbose or options.verboseClient:
108-
keepMsgs.append('TritonClient')
109-
if options.verbose or options.verboseService:
110-
keepMsgs.append('TritonService')
52+
defaultClient = applyClientOptions(getDefaultClientPSet().clone(), options)
11153

11254
for im,module in enumerate(options.modules):
11355
model = options.models[im]
11456
Module = [obj for name,obj in modules.items() if name in module][0]
11557
setattr(process, module,
11658
Module(module,
117-
Client = cms.PSet(
59+
Client = defaultClient.clone(
11860
mode = cms.string(options.mode),
11961
preferredServer = cms.untracked.string(""),
120-
timeout = cms.untracked.uint32(options.timeout),
121-
timeoutUnit = cms.untracked.string(options.timeoutUnit),
12262
modelName = cms.string(model),
12363
modelVersion = cms.string(""),
12464
modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
125-
verbose = cms.untracked.bool(options.verbose or options.verboseClient),
126-
allowedTries = cms.untracked.uint32(options.tries),
127-
useSharedMemory = cms.untracked.bool(not options.noShm),
128-
compression = cms.untracked.string(options.compression),
12965
)
13066
)
13167
)
@@ -148,8 +84,6 @@
14884
processModule.edgeMax = cms.uint32(15000)
14985
processModule.brief = cms.bool(options.brief)
15086
process.p += processModule
151-
if options.verbose or options.verboseClient:
152-
keepMsgs.extend([module,module+':TritonClient'])
15387
if options.testother:
15488
# clone modules to test both gRPC and shared memory
15589
# parenthesize the conditional so the module name prefixes both suffixes;
# unparenthesized, the else branch yields the bare string "SHM"
_module2 = module+("GRPC" if processModule.Client.useSharedMemory else "SHM")
@@ -160,19 +94,5 @@
16094
)
16195
processModule2 = getattr(process, _module2)
16296
process.p += processModule2
163-
if options.verbose or options.verboseClient:
164-
keepMsgs.extend([_module2,_module2+':TritonClient'])
165-
166-
process.load('FWCore/MessageService/MessageLogger_cfi')
167-
process.MessageLogger.cerr.FwkReport.reportEvery = 500
168-
for msg in keepMsgs:
169-
setattr(process.MessageLogger.cerr,msg,
170-
cms.untracked.PSet(
171-
limit = cms.untracked.int32(10000000),
172-
)
173-
)
174-
175-
if options.threads>0:
176-
process.options.numberOfThreads = options.threads
177-
process.options.numberOfStreams = options.streams
17897

98+
process = applyOptions(process, options)

0 commit comments

Comments
 (0)