Skip to content

Commit 55e32c5

Browse files
authored
Use semaphores in TaskUpdateRequestManager to avoid blocking task update threads (#685)
1 parent 2811256 commit 55e32c5

File tree

3 files changed

+109
-62
lines changed

3 files changed

+109
-62
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ All notable changes to this project will be documented in this file.
1818
- Move task metrics from `TaskUpdateManager` to `TaskService`. (#676)
1919
- Fail fast when tasks are detected past their contribution or final deadline. (#677)
2020
- Mitigate potential race conditions by enforcing `currentStatus` value when updating a task. (#681)
21+
- Use semaphores in `TaskUpdateRequestManager` to avoid blocking task update threads. (#685)
2122

2223
### Quality
2324

src/main/java/com/iexec/core/task/update/TaskUpdateRequestManager.java

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 IEXEC BLOCKCHAIN TECH
2+
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,13 +16,13 @@
1616

1717
package com.iexec.core.task.update;
1818

19-
import com.iexec.common.utils.ContextualLockRunner;
2019
import com.iexec.core.task.Task;
2120
import com.iexec.core.task.TaskService;
2221
import lombok.extern.slf4j.Slf4j;
22+
import net.jodah.expiringmap.ExpiringMap;
2323
import org.springframework.stereotype.Component;
2424

25-
import java.util.Optional;
25+
import java.util.concurrent.Semaphore;
2626
import java.util.concurrent.ThreadPoolExecutor;
2727
import java.util.concurrent.TimeUnit;
2828

@@ -42,8 +42,10 @@ public class TaskUpdateRequestManager {
4242
*/
4343
private static final int TASK_UPDATE_THREADS_POOL_SIZE = Runtime.getRuntime().availableProcessors() * 2;
4444

45-
private final ContextualLockRunner<String> taskExecutionLockRunner =
46-
new ContextualLockRunner<>(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
45+
// Working with semaphore to guarantee at most 1 item in queue and 1 running thread
46+
private final ExpiringMap<String, Semaphore> taskExecutionLockRunner = ExpiringMap.builder()
47+
.expiration(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS)
48+
.build();
4749

4850
final TaskUpdatePriorityBlockingQueue queue = new TaskUpdatePriorityBlockingQueue();
4951
// Both `corePoolSize` and `maximumPoolSize` should be set to `TASK_UPDATE_THREADS_POOL_SIZE`.
@@ -86,24 +88,32 @@ public synchronized boolean publishRequest(String chainTaskId) {
8688
log.debug("Request already published [chainTaskId:{}]", chainTaskId);
8789
return false;
8890
}
89-
final Optional<Task> oTask = taskService.getTaskByChainTaskId(chainTaskId);
90-
if (oTask.isEmpty()) {
91-
log.warn("No such task. [chainTaskId: {}]", chainTaskId);
91+
final Task task = taskService.getTaskByChainTaskId(chainTaskId).orElse(null);
92+
if (task == null) {
93+
log.warn("No such task [chainTaskId: {}]", chainTaskId);
9294
return false;
9395
}
9496

95-
final Task task = oTask.get();
97+
// Add semaphore to expiring map if missing
98+
taskExecutionLockRunner.putIfAbsent(chainTaskId, new Semaphore(1));
99+
96100
taskUpdateExecutor.execute(new TaskUpdate(task, this::updateTask));
97-
log.debug("Published task update request" +
98-
" [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
101+
log.debug("Published task update request [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
99102
chainTaskId, task.getCurrentStatus(), task.getContributionDeadline(), queue.size());
100103
return true;
101104
}
102105

103106
private void updateTask(String chainTaskId) {
104-
taskExecutionLockRunner.acceptWithLock(
105-
chainTaskId,
106-
taskUpdateManager::updateTask
107-
);
107+
if (!taskExecutionLockRunner.get(chainTaskId).tryAcquire()) {
108+
log.debug("Could not acquire lock for task update [chainTaskId:{}]", chainTaskId);
109+
return;
110+
}
111+
try {
112+
log.debug("Acquire lock for task update [chainTaskId:{}]", chainTaskId);
113+
taskUpdateManager.updateTask(chainTaskId);
114+
} finally {
115+
log.debug("Release lock for task update [chainTaskId:{}]", chainTaskId);
116+
taskExecutionLockRunner.get(chainTaskId).release();
117+
}
108118
}
109119
}

src/test/java/com/iexec/core/task/update/TaskUpdateRequestManagerTests.java

Lines changed: 83 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,31 +1,57 @@
1+
/*
2+
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package com.iexec.core.task.update;
218

319
import com.iexec.core.task.Task;
420
import com.iexec.core.task.TaskService;
521
import com.iexec.core.task.TaskStatus;
22+
import lombok.extern.slf4j.Slf4j;
623
import net.jodah.expiringmap.ExpiringMap;
724
import org.assertj.core.api.Assertions;
8-
import org.awaitility.Awaitility;
925
import org.junit.jupiter.api.BeforeEach;
1026
import org.junit.jupiter.api.Test;
27+
import org.junit.jupiter.api.extension.ExtendWith;
1128
import org.mockito.InjectMocks;
1229
import org.mockito.Mock;
1330
import org.mockito.MockitoAnnotations;
31+
import org.springframework.boot.test.system.CapturedOutput;
32+
import org.springframework.boot.test.system.OutputCaptureExtension;
33+
import org.springframework.test.util.ReflectionTestUtils;
1434

1535
import java.util.*;
1636
import java.util.concurrent.*;
1737
import java.util.function.Consumer;
1838
import java.util.stream.Collectors;
1939
import java.util.stream.Stream;
2040

21-
import static org.mockito.Mockito.when;
41+
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
42+
import static org.awaitility.Awaitility.await;
43+
import static org.mockito.Mockito.*;
2244

45+
@Slf4j
46+
@ExtendWith(OutputCaptureExtension.class)
2347
class TaskUpdateRequestManagerTests {
2448

2549
public static final String CHAIN_TASK_ID = "chainTaskId";
2650

2751
@Mock
2852
private TaskService taskService;
53+
@Mock
54+
private TaskUpdateManager taskUpdateManager;
2955

3056
@InjectMocks
3157
private TaskUpdateRequestManager taskUpdateRequestManager;
@@ -37,43 +63,69 @@ void init() {
3763

3864
// region publishRequest()
3965
@Test
40-
void shouldPublishRequest() throws ExecutionException, InterruptedException {
66+
void shouldPublishRequest(CapturedOutput output) {
4167
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
4268
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));
4369

44-
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
70+
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
71+
await().atMost(5L, TimeUnit.SECONDS)
72+
.until(() -> output.getOut().contains("Acquire lock for task update [chainTaskId:chainTaskId]")
73+
&& output.getOut().contains("Release lock for task update [chainTaskId:chainTaskId]"));
4574

46-
Assertions.assertThat(publishRequestStatus).isTrue();
75+
assertThat(publishRequestStatus).isTrue();
76+
verify(taskUpdateManager).updateTask(CHAIN_TASK_ID);
4777
}
4878

4979
@Test
50-
void shouldNotPublishRequestSinceEmptyTaskId() throws ExecutionException, InterruptedException {
51-
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");
80+
void shouldPublishRequestButNotAcquireLock(CapturedOutput output) {
81+
final ExpiringMap<String, Semaphore> locks = ExpiringMap.builder()
82+
.expiration(30L, TimeUnit.SECONDS)
83+
.build();
84+
locks.putIfAbsent(CHAIN_TASK_ID, new Semaphore(1));
85+
assertThat(locks.get(CHAIN_TASK_ID).tryAcquire()).isTrue();
86+
ReflectionTestUtils.setField(taskUpdateRequestManager, "taskExecutionLockRunner", locks);
87+
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
88+
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));
89+
90+
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
91+
await().atMost(5L, TimeUnit.SECONDS)
92+
.until(() -> output.getOut().contains("Could not acquire lock for task update [chainTaskId:chainTaskId]"));
5293

53-
Assertions.assertThat(publishRequestStatus).isFalse();
94+
assertThat(publishRequestStatus).isTrue();
95+
verifyNoInteractions(taskUpdateManager);
5496
}
5597

5698
@Test
57-
void shouldNotPublishRequestSinceItemAlreadyAdded() throws ExecutionException, InterruptedException {
99+
void shouldNotPublishRequestSinceEmptyTaskId() {
100+
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");
101+
102+
assertThat(publishRequestStatus).isFalse();
103+
verifyNoInteractions(taskService, taskUpdateManager);
104+
}
105+
106+
@Test
107+
void shouldNotPublishRequestSinceItemAlreadyAdded() {
58108
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
59109
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));
60110
taskUpdateRequestManager.queue.add(
61111
buildTaskUpdate(CHAIN_TASK_ID, null, null, null)
62112
);
63113

64-
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
114+
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
65115

66-
Assertions.assertThat(publishRequestStatus).isFalse();
116+
assertThat(publishRequestStatus).isFalse();
117+
verifyNoInteractions(taskUpdateManager);
67118
}
68119

69120
@Test
70-
void shouldNotPublishRequestSinceTaskDoesNotExist() throws ExecutionException, InterruptedException {
121+
void shouldNotPublishRequestSinceTaskDoesNotExist() {
71122
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
72123
.thenReturn(Optional.empty());
73124

74-
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
125+
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
75126

76-
Assertions.assertThat(publishRequestStatus).isFalse();
127+
assertThat(publishRequestStatus).isFalse();
128+
verifyNoInteractions(taskUpdateManager);
77129
}
78130
// endregion
79131

@@ -110,20 +162,18 @@ void shouldNotUpdateAtTheSameTime() {
110162
.collect(Collectors.toList());
111163

112164
updates.forEach(taskUpdateRequestManager.taskUpdateExecutor::execute);
113-
Awaitility
114-
.await()
115-
.timeout(30, TimeUnit.SECONDS)
165+
await().timeout(30, TimeUnit.SECONDS)
116166
.until(() -> callsOrder.size() == callsPerUpdate * updates.size());
117167

118-
Assertions.assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());
168+
assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());
119169

120170
// We loop through calls order and see if all calls for a given update have finished
121171
// before another update starts for this task.
122172
// Two updates for different tasks can run at the same time.
123173
Map<String, Map<Integer, Integer>> foundTaskUpdates = new HashMap<>();
124174

125175
for (int updateId : callsOrder) {
126-
System.out.println("[taskId:" + taskForUpdateId.get(updateId) + ", updateId:" + updateId + "]");
176+
log.info("[taskId:{}, updateId:{}]", taskForUpdateId.get(updateId), updateId);
127177
final Map<Integer, Integer> foundOutputsForKeyGroup = foundTaskUpdates.computeIfAbsent(taskForUpdateId.get(updateId), (key) -> new HashMap<>());
128178
for (int alreadyFound : foundOutputsForKeyGroup.keySet()) {
129179
if (!Objects.equals(alreadyFound, updateId) && foundOutputsForKeyGroup.get(alreadyFound) < callsPerUpdate) {
@@ -161,14 +211,12 @@ void shouldGetInOrderForStatus() throws InterruptedException {
161211
queue.addAll(tasks);
162212

163213
final List<TaskUpdate> prioritizedTasks = queue.takeAll();
164-
Assertions.assertThat(prioritizedTasks)
165-
.containsExactly(
166-
completedTask,
167-
consensusReachedTask,
168-
runningTask,
169-
initializedTask,
170-
initializingTask
171-
);
214+
assertThat(prioritizedTasks).containsExactly(
215+
completedTask,
216+
consensusReachedTask,
217+
runningTask,
218+
initializedTask,
219+
initializingTask);
172220
}
173221

174222
@Test
@@ -192,14 +240,8 @@ void shouldGetInOrderForContributionDeadline() throws InterruptedException {
192240
queue.addAll(tasks);
193241

194242
final List<TaskUpdate> prioritizedTasks = queue.takeAll();
195-
Assertions.assertThat(prioritizedTasks)
196-
.containsExactly(
197-
t1,
198-
t2,
199-
t3,
200-
t4,
201-
t5
202-
);
243+
assertThat(prioritizedTasks).containsExactly(
244+
t1, t2, t3, t4, t5);
203245
}
204246

205247
@Test
@@ -219,13 +261,8 @@ void shouldGetInOrderForStatusAndContributionDeadline() throws InterruptedExcept
219261
queue.addAll(tasks);
220262

221263
final List<TaskUpdate> prioritizedTasks = queue.takeAll();
222-
Assertions.assertThat(prioritizedTasks)
223-
.containsExactly(
224-
t3,
225-
t4,
226-
t1,
227-
t2
228-
);
264+
assertThat(prioritizedTasks).containsExactly(
265+
t3, t4, t1, t2);
229266
}
230267
// endregion
231268

@@ -234,12 +271,11 @@ private TaskUpdate buildTaskUpdate(String chainTaskId,
234271
Date contributionDeadline,
235272
Consumer<String> taskUpdater) {
236273
return new TaskUpdate(
237-
Task
238-
.builder()
274+
Task.builder()
239275
.chainTaskId(chainTaskId)
240-
.currentStatus(status).
241-
contributionDeadline(contributionDeadline).
242-
build(),
276+
.currentStatus(status)
277+
.contributionDeadline(contributionDeadline)
278+
.build(),
243279
taskUpdater
244280
);
245281
}

0 commit comments

Comments
 (0)