@@ -43,6 +43,9 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
4343 args.maxSeqLen = 0 ;
4444 args.netTurbo = true ;
4545 args.gpuIndex = -1 ;
46+ args.gpuSegmentFrom = -1 ;
47+ args.gpuSegmentTo = -1 ;
48+
4649 int i = 1 ;
4750 if (requireMode && argc > 1 ) {
4851 args.mode = argv[1 ];
@@ -79,15 +82,15 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
7982
8083 for (int s = 0 ; s < count; s++) {
8184 char *v = argv[i + 1 + s];
82- char *sep = std::strstr (v, " :" );
83- if (sep == NULL ) {
85+ char *separator = std::strstr (v, " :" );
86+ if (separator == NULL ) {
8487 throw std::runtime_error (" Invalid worker address: " + std::string (v));
8588 }
86- int hostLen = sep - v;
89+ int hostLen = separator - v;
8790 args.workerHosts [s] = new char [hostLen + 1 ];
8891 std::memcpy (args.workerHosts [s], v, hostLen);
8992 args.workerHosts [s][hostLen] = ' \0 ' ;
90- args.workerPorts [s] = std::atoi (sep + 1 );
93+ args.workerPorts [s] = std::atoi (separator + 1 );
9194 }
9295
9396 i += count - 1 ;
@@ -109,6 +112,12 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
109112 args.maxSeqLen = (unsigned int )atoi (value);
110113 } else if (std::strcmp (name, " --gpu-index" ) == 0 ) {
111114 args.gpuIndex = atoi (value);
115+ } else if (std::strcmp (name, " --gpu-segments" ) == 0 ) {
116+ char *separator = std::strstr (value, " :" );
117+ if (separator == NULL )
118+ throw std::runtime_error (" GPU segments expected in the format <from>:<to>" );
119+ args.gpuSegmentFrom = atoi (value);
120+ args.gpuSegmentTo = atoi (separator + 1 );
112121 } else if (std::strcmp (name, " --net-turbo" ) == 0 ) {
113122 args.netTurbo = atoi (value) == 1 ;
114123 } else {
@@ -128,23 +137,32 @@ AppCliArgs::~AppCliArgs() {
128137 delete[] workerPorts;
129138}
130139
// Replaces createDevice(): instead of a single owning NnDevice*, the node now
// resolves a LIST of executor devices, each paired with an optional layer-segment
// range [gpuSegmentFrom, gpuSegmentTo] so work can be split between GPU and CPU.
// NOTE(review): ownership of the raw `new` device pointers appears to transfer
// into NnExecutorDevice / NnExecutor — confirm against their definitions, which
// are outside this diff.
131- static NnDevice *createDevice (AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
140+ static std::vector<NnExecutorDevice> resolveDevices (AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
141+ std::vector<NnExecutorDevice> devices;
142+
132143 if (args->gpuIndex >= 0 ) { // GPU explicitly requested via --gpu-index
133144#if defined(DLLAMA_VULKAN)
134- return new NnVulkanDevice (args->gpuIndex , netConfig, nodeConfig, netExecution);
145+ devices.push_back (NnExecutorDevice (
146+ new NnVulkanDevice (args->gpuIndex , netConfig, nodeConfig, netExecution),
147+ args->gpuSegmentFrom ,
148+ args->gpuSegmentTo
149+ )); // segment range is -1/-1 ("whole net") unless --gpu-segments was parsed
135150#else
136151 throw std::runtime_error (" This build does not support GPU" );
137152#endif
138153 }
139- return new NnCpuDevice (netConfig, nodeConfig, netExecution);
154+
155+ // Add a CPU device when either (a) no GPU was requested, or (b) the GPU only
156+ // covers an explicit segment, so the CPU must execute the remaining segments.
157+ // A GPU with no segment range (full coverage) therefore gets NO CPU sibling.
155+ if (args->gpuIndex < 0 || (args->gpuSegmentFrom >= 0 && args->gpuSegmentTo >= 0 )) {
156+ devices.push_back (NnExecutorDevice (new NnCpuDevice (netConfig, nodeConfig, netExecution), -1 , -1 ));
157+ }
158+ return devices;
140159}
141160
142- RootLlmInference::RootLlmInference (LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
161+ RootLlmInference::RootLlmInference (LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
143162 this ->header = net->header ;
144163 this ->tokenPipe = (float *)execution->pipes [net->tokenPipeIndex ];
145164 this ->positionPipe = (float *)execution->pipes [net->positionPipeIndex ];
146165 this ->logitsPipe = (float *)execution->pipes [net->logitsPipeIndex ];
147- this ->device = device;
148166 this ->execution = execution;
149167 this ->executor = executor;
150168 this ->network = network; // May be nullptr!
@@ -245,13 +263,13 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
245263 configWriter.writeToWorkers (&net.netConfig , net.nodeConfigs );
246264 }
247265
248- std::unique_ptr<NnDevice> device ( createDevice ( args, &net.netConfig , rootNodeConfig, &execution) );
249- NnExecutor executor (&net.netConfig , rootNodeConfig, device. get () , &execution, synchronizer.get (), args->benchmark );
266+ std::vector<NnExecutorDevice> devices = resolveDevices ( args, &net.netConfig , rootNodeConfig, &execution);
267+ NnExecutor executor (&net.netConfig , rootNodeConfig, &devices , &execution, synchronizer.get (), args->benchmark );
250268
251269 NnRootWeightLoader weightLoader (&executor, network, nNodes);
252270 loadLlmNetWeight (args->modelPath , &net, &weightLoader);
253271
254- RootLlmInference inference (&net, device. get (), &execution, &executor, network);
272+ RootLlmInference inference (&net, &execution, &executor, network);
255273
256274 if (network != nullptr ) {
257275 network->resetStats ();
@@ -290,10 +308,9 @@ void runWorkerApp(AppCliArgs *args) {
290308
291309 NnNetExecution execution (args->nThreads , &netConfig);
292310
293- std::unique_ptr<NnDevice> device (createDevice (args, &netConfig, &nodeConfig, &execution));
294-
311+ std::vector<NnExecutorDevice> devices = resolveDevices (args, &netConfig, &nodeConfig, &execution);
295312 NnNetworkNodeSynchronizer synchronizer (network, &execution, &netConfig, &nodeConfig);
296- NnExecutor executor (&netConfig, &nodeConfig, device. get () , &execution, &synchronizer, false );
313+ NnExecutor executor (&netConfig, &nodeConfig, &devices , &execution, &synchronizer, false );
297314
298315 NnWorkerWeightReader weightReader (&executor, network);
299316 weightReader.read ();
0 commit comments