Skip to content

Commit 37e095e

Browse files
committed
feat: execute pre-compute container for a TEE task requesting a bulk processing
1 parent 1f0bb7f commit 37e095e

File tree

2 files changed

+68
-85
lines changed

2 files changed

+68
-85
lines changed

src/main/java/com/iexec/worker/compute/pre/PreComputeService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public PreComputeResponse runTeePreCompute(TaskDescription taskDescription, Work
139139
}
140140

141141
// run TEE pre-compute container if needed
142-
if (taskDescription.containsDataset() || taskDescription.containsInputFiles()) {
142+
if (taskDescription.requiresPreCompute()) {
143143
log.info("Task contains TEE input data [chainTaskId:{}, containsDataset:{}, containsInputFiles:{}]",
144144
chainTaskId, taskDescription.containsDataset(), taskDescription.containsInputFiles());
145145
final ReplicateStatusCause exitCause = downloadDatasetAndFiles(taskDescription, secureSession);

src/test/java/com/iexec/worker/compute/pre/PreComputeServiceTests.java

Lines changed: 67 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
import com.iexec.worker.tee.TeeService;
4343
import com.iexec.worker.tee.TeeServicesManager;
4444
import com.iexec.worker.tee.TeeServicesPropertiesService;
45-
import org.assertj.core.api.Assertions;
4645
import org.junit.jupiter.api.BeforeEach;
4746
import org.junit.jupiter.api.Test;
4847
import org.junit.jupiter.params.ParameterizedTest;
@@ -58,7 +57,7 @@
5857

5958
import static com.iexec.common.replicate.ReplicateStatusCause.*;
6059
import static com.iexec.sms.api.TeeSessionGenerationError.*;
61-
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
60+
import static org.assertj.core.api.Assertions.assertThat;
6261
import static org.mockito.ArgumentMatchers.any;
6362
import static org.mockito.Mockito.*;
6463

@@ -70,6 +69,7 @@ class PreComputeServiceTests {
7069
private static final String PRE_COMPUTE_ENTRYPOINT = "preComputeEntrypoint";
7170
private final String chainTaskId = "chainTaskId";
7271
private final String datasetUri = "datasetUri";
72+
private final String network = "network";
7373
private final TaskDescription.TaskDescriptionBuilder taskDescriptionBuilder = TaskDescription.builder()
7474
.chainTaskId(chainTaskId)
7575
.datasetAddress("datasetAddress")
@@ -130,13 +130,7 @@ void beforeEach() {
130130
}
131131

132132
//region runTeePreCompute
133-
@Test
134-
void shouldRunTeePreComputeAndPrepareInputDataWhenDatasetAndInputFilesArePresent() throws TeeSessionGenerationException {
135-
final DealParams dealParams = DealParams.builder()
136-
.iexecInputFiles(List.of("input-file1"))
137-
.build();
138-
final TaskDescription taskDescription = taskDescriptionBuilder.dealParams(dealParams).build();
139-
133+
void prepareMockWhenPreComputeShouldRunForTask(final TaskDescription taskDescription) throws TeeSessionGenerationException {
140134
when(smsService.createTeeSession(workerpoolAuthorization)).thenReturn(secureSession);
141135
when(preComputeProperties.getImage()).thenReturn(PRE_COMPUTE_IMAGE);
142136
when(preComputeProperties.getHeapSizeInBytes()).thenReturn(PRE_COMPUTE_HEAP);
@@ -146,63 +140,54 @@ void shouldRunTeePreComputeAndPrepareInputDataWhenDatasetAndInputFilesArePresent
146140
when(teeMockedService.buildPreComputeDockerEnv(taskDescription, secureSession))
147141
.thenReturn(List.of("env"));
148142
when(dockerService.getInputBind(chainTaskId)).thenReturn(IEXEC_IN_BIND);
149-
String network = "network";
150143
when(workerConfigService.getDockerNetworkName()).thenReturn(network);
151144
when(dockerService.run(any())).thenReturn(DockerRunResponse.builder()
152145
.containerExitCode(0)
153146
.finalStatus(DockerRunFinalStatus.SUCCESS)
154147
.executionDuration(Duration.ofSeconds(10))
155148
.build());
156149
when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY);
150+
}
157151

158-
Assertions.assertThat(taskDescription.containsDataset()).isTrue();
159-
Assertions.assertThat(taskDescription.containsInputFiles()).isTrue();
160-
Assertions.assertThat(preComputeService
161-
.runTeePreCompute(taskDescription, workerpoolAuthorization))
162-
.isEqualTo(PreComputeResponse.builder().secureSession(secureSession).build());
152+
void verifyDockerRun() {
163153
verify(dockerService).run(captor.capture());
164154
DockerRunRequest capturedRequest = captor.getValue();
165-
Assertions.assertThat(capturedRequest.getImageUri()).isEqualTo(PRE_COMPUTE_IMAGE);
166-
Assertions.assertThat(capturedRequest.getEntrypoint()).isEqualTo(PRE_COMPUTE_ENTRYPOINT);
167-
Assertions.assertThat(capturedRequest.getSgxDriverMode()).isEqualTo(SgxDriverMode.LEGACY);
168-
Assertions.assertThat(capturedRequest.getHostConfig().getNetworkMode()).isEqualTo(network);
169-
Assertions.assertThat(capturedRequest.getHostConfig().getBinds()[0]).hasToString(IEXEC_IN_BIND + ":rw");
155+
assertThat(capturedRequest.getImageUri()).isEqualTo(PRE_COMPUTE_IMAGE);
156+
assertThat(capturedRequest.getEntrypoint()).isEqualTo(PRE_COMPUTE_ENTRYPOINT);
157+
assertThat(capturedRequest.getSgxDriverMode()).isEqualTo(SgxDriverMode.LEGACY);
158+
assertThat(capturedRequest.getHostConfig().getNetworkMode()).isEqualTo(network);
159+
assertThat(capturedRequest.getHostConfig().getBinds()[0]).hasToString(IEXEC_IN_BIND + ":rw");
160+
}
161+
162+
@Test
163+
void shouldRunTeePreComputeAndPrepareInputDataWhenDatasetAndInputFilesArePresent() throws TeeSessionGenerationException {
164+
final DealParams dealParams = DealParams.builder()
165+
.iexecInputFiles(List.of("input-file1"))
166+
.build();
167+
final TaskDescription taskDescription = taskDescriptionBuilder.dealParams(dealParams).build();
168+
169+
prepareMockWhenPreComputeShouldRunForTask(taskDescription);
170+
171+
assertThat(taskDescription.containsDataset()).isTrue();
172+
assertThat(taskDescription.containsInputFiles()).isTrue();
173+
assertThat(taskDescription.isBulkRequest()).isFalse();
174+
assertThat(preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization))
175+
.isEqualTo(PreComputeResponse.builder().secureSession(secureSession).build());
176+
verifyDockerRun();
170177
}
171178

172179
@Test
173180
void shouldRunTeePreComputeAndPrepareInputDataWhenOnlyDatasetIsPresent() throws TeeSessionGenerationException {
174181
final TaskDescription taskDescription = taskDescriptionBuilder.build();
175182

176-
when(smsService.createTeeSession(workerpoolAuthorization)).thenReturn(secureSession);
177-
when(preComputeProperties.getImage()).thenReturn(PRE_COMPUTE_IMAGE);
178-
when(preComputeProperties.getHeapSizeInBytes()).thenReturn(PRE_COMPUTE_HEAP);
179-
when(preComputeProperties.getEntrypoint()).thenReturn(PRE_COMPUTE_ENTRYPOINT);
180-
when(dockerClientInstanceMock.isImagePresent(PRE_COMPUTE_IMAGE))
181-
.thenReturn(true);
182-
when(teeMockedService.buildPreComputeDockerEnv(taskDescription, secureSession))
183-
.thenReturn(List.of("env"));
184-
when(dockerService.getInputBind(chainTaskId)).thenReturn(IEXEC_IN_BIND);
185-
String network = "network";
186-
when(workerConfigService.getDockerNetworkName()).thenReturn(network);
187-
when(dockerService.run(any())).thenReturn(DockerRunResponse.builder()
188-
.containerExitCode(0)
189-
.finalStatus(DockerRunFinalStatus.SUCCESS)
190-
.executionDuration(Duration.ofSeconds(10))
191-
.build());
192-
when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY);
183+
prepareMockWhenPreComputeShouldRunForTask(taskDescription);
193184

194-
Assertions.assertThat(taskDescription.containsDataset()).isTrue();
195-
Assertions.assertThat(taskDescription.containsInputFiles()).isFalse();
196-
Assertions.assertThat(preComputeService
197-
.runTeePreCompute(taskDescription, workerpoolAuthorization))
185+
assertThat(taskDescription.containsDataset()).isTrue();
186+
assertThat(taskDescription.containsInputFiles()).isFalse();
187+
assertThat(taskDescription.isBulkRequest()).isFalse();
188+
assertThat(preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization))
198189
.isEqualTo(PreComputeResponse.builder().secureSession(secureSession).build());
199-
verify(dockerService).run(captor.capture());
200-
DockerRunRequest capturedRequest = captor.getValue();
201-
Assertions.assertThat(capturedRequest.getImageUri()).isEqualTo(PRE_COMPUTE_IMAGE);
202-
Assertions.assertThat(capturedRequest.getEntrypoint()).isEqualTo(PRE_COMPUTE_ENTRYPOINT);
203-
Assertions.assertThat(capturedRequest.getSgxDriverMode()).isEqualTo(SgxDriverMode.LEGACY);
204-
Assertions.assertThat(capturedRequest.getHostConfig().getNetworkMode()).isEqualTo(network);
205-
Assertions.assertThat(capturedRequest.getHostConfig().getBinds()[0]).hasToString(IEXEC_IN_BIND + ":rw");
190+
verifyDockerRun();
206191
}
207192

208193

@@ -216,36 +201,34 @@ void shouldRunTeePreComputeAndPrepareInputDataWhenOnlyInputFilesArePresent() thr
216201
.dealParams(dealParams)
217202
.build();
218203

219-
when(smsService.createTeeSession(workerpoolAuthorization)).thenReturn(secureSession);
220-
when(preComputeProperties.getImage()).thenReturn(PRE_COMPUTE_IMAGE);
221-
when(preComputeProperties.getHeapSizeInBytes()).thenReturn(PRE_COMPUTE_HEAP);
222-
when(preComputeProperties.getEntrypoint()).thenReturn(PRE_COMPUTE_ENTRYPOINT);
223-
when(dockerClientInstanceMock.isImagePresent(PRE_COMPUTE_IMAGE))
224-
.thenReturn(true);
225-
when(teeMockedService.buildPreComputeDockerEnv(taskDescription, secureSession))
226-
.thenReturn(List.of("env"));
227-
when(dockerService.getInputBind(chainTaskId)).thenReturn(IEXEC_IN_BIND);
228-
String network = "network";
229-
when(workerConfigService.getDockerNetworkName()).thenReturn(network);
230-
when(dockerService.run(any())).thenReturn(DockerRunResponse.builder()
231-
.containerExitCode(0)
232-
.finalStatus(DockerRunFinalStatus.SUCCESS)
233-
.executionDuration(Duration.ofSeconds(10))
234-
.build());
235-
when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY);
204+
prepareMockWhenPreComputeShouldRunForTask(taskDescription);
236205

237-
Assertions.assertThat(taskDescription.containsDataset()).isFalse();
238-
Assertions.assertThat(taskDescription.containsInputFiles()).isTrue();
239-
Assertions.assertThat(preComputeService
240-
.runTeePreCompute(taskDescription, workerpoolAuthorization))
206+
assertThat(taskDescription.containsDataset()).isFalse();
207+
assertThat(taskDescription.containsInputFiles()).isTrue();
208+
assertThat(taskDescription.isBulkRequest()).isFalse();
209+
assertThat(preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization))
241210
.isEqualTo(PreComputeResponse.builder().secureSession(secureSession).build());
242-
verify(dockerService).run(captor.capture());
243-
DockerRunRequest capturedRequest = captor.getValue();
244-
Assertions.assertThat(capturedRequest.getImageUri()).isEqualTo(PRE_COMPUTE_IMAGE);
245-
Assertions.assertThat(capturedRequest.getEntrypoint()).isEqualTo(PRE_COMPUTE_ENTRYPOINT);
246-
Assertions.assertThat(capturedRequest.getSgxDriverMode()).isEqualTo(SgxDriverMode.LEGACY);
247-
Assertions.assertThat(capturedRequest.getHostConfig().getNetworkMode()).isEqualTo(network);
248-
Assertions.assertThat(capturedRequest.getHostConfig().getBinds()[0]).hasToString(IEXEC_IN_BIND + ":rw");
211+
verifyDockerRun();
212+
}
213+
214+
@Test
215+
void shouldRunTeePreComputeAndPrepareInputDataWhenBulkProcessingRequested() throws TeeSessionGenerationException {
216+
final DealParams dealParams = DealParams.builder()
217+
.bulkCid("bulk_cid")
218+
.build();
219+
final TaskDescription taskDescription = taskDescriptionBuilder
220+
.datasetAddress("")
221+
.dealParams(dealParams)
222+
.build();
223+
224+
prepareMockWhenPreComputeShouldRunForTask(taskDescription);
225+
226+
assertThat(taskDescription.containsDataset()).isFalse();
227+
assertThat(taskDescription.containsInputFiles()).isFalse();
228+
assertThat(taskDescription.isBulkRequest()).isTrue();
229+
assertThat(preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization))
230+
.isEqualTo(PreComputeResponse.builder().secureSession(secureSession).build());
231+
verifyDockerRun();
249232
}
250233

251234
@Test
@@ -307,8 +290,8 @@ void shouldNotRunTeePreComputeSinceDockerImageNotFoundLocally() throws TeeSessio
307290
.thenReturn(false);
308291

309292
final PreComputeResponse preComputeResponse = preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization);
310-
Assertions.assertThat(preComputeResponse.isSuccessful()).isFalse();
311-
Assertions.assertThat(preComputeResponse.getExitCause()).isEqualTo(ReplicateStatusCause.PRE_COMPUTE_IMAGE_MISSING);
293+
assertThat(preComputeResponse.isSuccessful()).isFalse();
294+
assertThat(preComputeResponse.getExitCause()).isEqualTo(ReplicateStatusCause.PRE_COMPUTE_IMAGE_MISSING);
312295
verify(dockerService, never()).run(any());
313296
}
314297

@@ -336,9 +319,9 @@ void shouldFailToRunTeePreComputeSinceDockerRunFailed(Map.Entry<Integer, Replica
336319
PreComputeResponse preComputeResponse =
337320
preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization);
338321

339-
Assertions.assertThat(preComputeResponse.isSuccessful())
322+
assertThat(preComputeResponse.isSuccessful())
340323
.isFalse();
341-
Assertions.assertThat(preComputeResponse.getExitCause())
324+
assertThat(preComputeResponse.getExitCause())
342325
.isEqualTo(exitCodeKeyToExpectedCauseValue.getValue());
343326
verify(dockerService).run(any());
344327
}
@@ -372,9 +355,9 @@ void shouldFailToRunTeePreComputeSinceTimeout() throws TeeSessionGenerationExcep
372355
PreComputeResponse preComputeResponse =
373356
preComputeService.runTeePreCompute(taskDescription, workerpoolAuthorization);
374357

375-
Assertions.assertThat(preComputeResponse.isSuccessful())
358+
assertThat(preComputeResponse.isSuccessful())
376359
.isFalse();
377-
Assertions.assertThat(preComputeResponse.getExitCause())
360+
assertThat(preComputeResponse.getExitCause())
378361
.isEqualTo(ReplicateStatusCause.PRE_COMPUTE_TIMEOUT);
379362
verify(dockerService).run(any());
380363
}
@@ -425,14 +408,14 @@ static Stream<Arguments> teeSessionGenerationErrorMap() {
425408
@ParameterizedTest
426409
@MethodSource("teeSessionGenerationErrorMap")
427410
void shouldConvertTeeSessionGenerationError(TeeSessionGenerationError error, ReplicateStatusCause expectedCause) {
428-
Assertions.assertThat(preComputeService.teeSessionGenerationErrorToReplicateStatusCause(error))
411+
assertThat(preComputeService.teeSessionGenerationErrorToReplicateStatusCause(error))
429412
.isEqualTo(expectedCause);
430413
}
431414

432415
@Test
433416
void shouldAllTeeSessionGenerationErrorHaveMatch() {
434417
for (TeeSessionGenerationError error : TeeSessionGenerationError.values()) {
435-
Assertions.assertThat(preComputeService.teeSessionGenerationErrorToReplicateStatusCause(error))
418+
assertThat(preComputeService.teeSessionGenerationErrorToReplicateStatusCause(error))
436419
.isNotNull();
437420
}
438421
}

0 commit comments

Comments
 (0)