|
26 | 26 |
|
27 | 27 | import com.gpuopenanalytics.jenkins.remotedocker.AbstractDockerLauncher; |
28 | 28 | import hudson.Extension; |
| 29 | +import hudson.Launcher; |
29 | 30 | import hudson.model.Descriptor; |
30 | 31 | import hudson.util.ArgumentListBuilder; |
31 | 32 | import org.apache.commons.lang.StringUtils; |
32 | 33 | import org.jenkinsci.Symbol; |
33 | 34 | import org.kohsuke.stapler.DataBoundConstructor; |
34 | 35 |
|
| 36 | +import java.io.ByteArrayOutputStream; |
| 37 | +import java.io.IOException; |
| 38 | +import java.nio.charset.StandardCharsets; |
| 39 | +import java.util.Arrays; |
| 40 | +import java.util.ArrayList; |
| 41 | +import java.util.List; |
| 42 | +import java.util.regex.Matcher; |
| 43 | +import java.util.regex.Pattern; |
| 44 | + |
35 | 45 | /** |
36 | 46 | * Defines which GPU devices are visible in the container. Passes |
37 | 47 | * <code>-e NVIDIA_VISIBLE_DEVICES=value</code> |
@@ -69,7 +79,13 @@ public void addCreateArgs(AbstractDockerLauncher launcher, |
69 | 79 | ArgumentListBuilder args) { |
70 | 80 | String value; |
71 | 81 | if ("executor".equals(getValue())) { |
72 | | - value = launcher.getEnvironment().get("EXECUTOR_NUMBER"); |
| 82 | + String executorNum = launcher.getEnvironment().get("EXECUTOR_NUMBER"); |
| 83 | + String nvidiasmiOutput = executeWithOutput(launcher.getInner(), "nvidia-smi", "-L"); |
| 84 | + if (isMIG(nvidiasmiOutput)) { |
| 85 | + value = getMIG(nvidiasmiOutput, executorNum); |
| 86 | + } else { |
| 87 | + value = executorNum; |
| 88 | + } |
73 | 89 | } else { |
74 | 90 | value = getResolvedValue(launcher); |
75 | 91 | } |
@@ -99,4 +115,45 @@ public String getDisplayName() { |
99 | 115 | return "NVIDIA Device Visibility"; |
100 | 116 | } |
101 | 117 | } |
| 118 | + |
| 119 | + private String executeWithOutput(Launcher launcher, String... args) { |
| 120 | + try { |
| 121 | + ByteArrayOutputStream baos = new ByteArrayOutputStream(); |
| 122 | + int status = launcher.launch() |
| 123 | + .cmds(args) |
| 124 | + .stdout(baos) |
| 125 | + .stderr(launcher.getListener().getLogger()) |
| 126 | + .join(); |
| 127 | + if (status != 0) { |
| 128 | + throw new RuntimeException( |
| 129 | + "Non-zero status " + status + ": " + Arrays |
| 130 | + .toString(args)); |
| 131 | + } |
| 132 | + return baos.toString(StandardCharsets.UTF_8.name()).trim(); |
| 133 | + } catch (InterruptedException | IOException e) { |
| 134 | + throw new RuntimeException(e); |
| 135 | + } |
| 136 | + } |
| 137 | + |
| 138 | + private boolean isMIG(String output) { |
| 139 | + Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)"); |
| 140 | + Matcher m = pattern.matcher(output); |
| 141 | + |
| 142 | + if (m.find()) { |
| 143 | + return true; |
| 144 | + } |
| 145 | + return false; |
| 146 | + } |
| 147 | + |
| 148 | + private String getMIG(String output, String executor) { |
| 149 | + int executorNum = Integer.parseInt(executor); |
| 150 | + List<String> uuids = new ArrayList<String>(); |
| 151 | + Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)"); |
| 152 | + Matcher m = pattern.matcher(output); |
| 153 | + |
| 154 | + while (m.find()) { |
| 155 | + uuids.add(m.group()); |
| 156 | + } |
| 157 | + return uuids.get(executorNum); |
| 158 | + } |
102 | 159 | } |
0 commit comments