diff --git a/docs/configuration.md b/docs/configuration.md index 77bf13265..10963a95b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -172,7 +172,7 @@ Most of those properties are designed for performance tuning. Adjusting those nu * enable_envvars_config: Enable configuring MMS through environment variables. When this option is set to "true", all the static configurations of MMS can come through environment variables as well. default: false * number_of_netty_threads: number frontend netty thread, default: number of logical processors available to the JVM. * netty_client_threads: number of backend netty thread, default: number of logical processors available to the JVM. -* default_workers_per_model: number of workers to create for each model that loaded at startup time, default: available GPUs in system or number of logical processors available to the JVM. +* default_workers_per_model: number of workers to create for each model that is loaded at startup time, default: the number of available GPUs in the system; if that is 0, the number of available Neuron cores; if that is also 0, the number of logical processors available to the JVM. * job_queue_size: number inference jobs that frontend will queue before backend can serve, default 100. Useful in cases where certain requests take predictably longer to complete. * async_logging: enable asynchronous logging for higher throughput, log output may be delayed if this is enabled, default: false. * default_response_timeout: Timeout, in seconds, used for model's backend workers before they are deemed unresponsive and rebooted. default: 120 seconds. 
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java
index ad09e2e0e..9b445128a 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java
@@ -12,6 +12,7 @@
  */
 package com.amazonaws.ml.mms.util;
 
+import com.google.gson.annotations.SerializedName;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.util.SelfSignedCertificate;
@@ -19,6 +20,8 @@
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.Reader;
 import java.lang.reflect.Field;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
@@ -67,6 +70,7 @@ public final class ConfigManager {
     private static final String MMS_NETTY_CLIENT_THREADS = "netty_client_threads";
     private static final String MMS_JOB_QUEUE_SIZE = "job_queue_size";
     private static final String MMS_NUMBER_OF_GPU = "number_of_gpu";
+    private static final String MMS_NUMBER_OF_NEURON_CORES = "number_of_neuron_cores";
     private static final String MMS_ASYNC_LOGGING = "async_logging";
     private static final String MMS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
     private static final String MMS_CORS_ALLOWED_METHODS = "cors_allowed_methods";
@@ -143,6 +147,13 @@ private ConfigManager(Arguments args) {
                                 getAvailableGpu(),
                                 getIntProperty(MMS_NUMBER_OF_GPU, Integer.MAX_VALUE))));
 
+        prop.setProperty(
+                MMS_NUMBER_OF_NEURON_CORES,
+                String.valueOf(
+                        Integer.min(
+                                getAvailableNeuronCores(),
+                                getIntProperty(MMS_NUMBER_OF_NEURON_CORES, Integer.MAX_VALUE))));
+
         String pythonExecutable = args.getPythonExecutable();
         if (pythonExecutable != null) {
             prop.setProperty("PYTHON_EXECUTABLE", pythonExecutable);
@@ -258,6 +269,15 @@ public int getNumberOfGpu() {
         return getIntProperty(MMS_NUMBER_OF_GPU, 0);
     }
 
+    /**
+     * Returns the configured number of Neuron cores.
+     *
+     * @return the {@code number_of_neuron_cores} property, read with a fallback of 0
+     */
+    public int getNumberOfNeuronCores() {
+        return getIntProperty(MMS_NUMBER_OF_NEURON_CORES, 0);
+    }
+
     public String getMmsDefaultServiceHandler() {
         return getProperty(MMS_DEFAULT_SERVICE_HANDLER, null);
     }
@@ -283,6 +303,9 @@ public int getDefaultWorkers() {
         if (workers == 0) {
             workers = getNumberOfGpu();
         }
+        if (workers == 0) {
+            workers = getNumberOfNeuronCores();
+        }
         if (workers == 0) {
             workers = Runtime.getRuntime().availableProcessors();
         }
@@ -453,6 +476,8 @@ public String dumpConfigurations() {
                 + System.getProperty("java.io.tmpdir")
                 + "\nNumber of GPUs: "
                 + getNumberOfGpu()
+                + "\nNumber of Neuron Cores: "
+                + getNumberOfNeuronCores()
                 + "\nNumber of CPUs: "
                 + runtime.availableProcessors()
                 + "\nMax heap size: "
@@ -587,6 +612,44 @@ private static int getAvailableGpu() {
         }
     }
 
+    /**
+     * Counts the Neuron cores reported by the {@code neuron-ls --json-output} command.
+     *
+     * @return the sum of {@code nc_count} over all reported entries, or 0 if the command cannot
+     *     be started, exits with a non-zero status, or prints nothing that parses as JSON
+     */
+    private static int getAvailableNeuronCores() {
+        try {
+            Process process = Runtime.getRuntime().exec("neuron-ls --json-output");
+            NeuronConfig[] results;
+            // Drain stdout before waitFor(): a child blocked on a full pipe would never exit.
+            try (Reader reader =
+                    new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8)) {
+                results = JsonUtils.GSON.fromJson(reader, NeuronConfig[].class);
+            }
+            // fromJson yields null for empty output, so guard before streaming.
+            if (process.waitFor() != 0 || results == null) {
+                return 0;
+            }
+            return Arrays.stream(results)
+                    .mapToInt(c -> c == null ? 0 : c.numNeuronCores)
+                    .sum();
+        } catch (IOException | com.google.gson.JsonParseException e) {
+            // Missing tool or unparseable output is treated as "no Neuron cores".
+            return 0;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // keep the interrupt visible to callers
+            return 0;
+        }
+    }
+
+    /** One element of the JSON array printed by {@code neuron-ls --json-output}. */
+    private static final class NeuronConfig {
+        /** Core count for this entry, read from the {@code nc_count} field. */
+        @SerializedName("nc_count")
+        int numNeuronCores;
+    }
+
     public static final class Arguments {
 
         private String mmsConfigFile;