From f14b23e04703d256348909b7334272be60abde0a Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Sat, 27 May 2023 23:43:59 +0300 Subject: [PATCH 1/9] task queue draft --- .../lzy/allocator/alloc/AllocateVmAction.java | 2 +- .../allocator/alloc/DeleteSessionAction.java | 8 - .../alloc/MountDynamicDiskAction.java | 2 +- .../alloc/UnmountDynamicDiskAction.java | 2 +- .../db/allocator/migrations/V8__task.sql | 21 ++ lzy/long-running/pom.xml | 5 + .../lzy/longrunning/OperationRunnerBase.java | 13 +- .../ai/lzy/longrunning/task/DaoTaskQueue.java | 100 ++++++++ .../java/ai/lzy/longrunning/task/Task.java | 74 ++++++ .../lzy/longrunning/task/TaskAwareAction.java | 58 +++++ .../ai/lzy/longrunning/task/TaskExecutor.java | 63 +++++ .../ai/lzy/longrunning/task/TaskQueue.java | 19 ++ .../ai/lzy/longrunning/task/TaskResolver.java | 5 + .../ai/lzy/longrunning/task/dao/TaskDao.java | 29 +++ .../lzy/longrunning/task/dao/TaskDaoImpl.java | 227 ++++++++++++++++++ .../ai/lzy/longrunning/TaskDaoImplTest.java | 139 +++++++++++ .../test/resources/db/migrations/V1__task.sql | 18 ++ .../src/test/resources/log4j2.yaml | 28 +++ .../lzy/model/db/test/DatabaseTestUtils.java | 23 ++ 19 files changed, 823 insertions(+), 13 deletions(-) create mode 100644 lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskResolver.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java create mode 100644 
lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDaoImpl.java create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java create mode 100644 lzy/long-running/src/test/resources/db/migrations/V1__task.sql create mode 100644 lzy/long-running/src/test/resources/log4j2.yaml diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java index 0f1f0befc2..bb6625d4e8 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java @@ -67,7 +67,7 @@ protected void notifyExpired() { } @Override - protected void notifyFinished() { + protected void notifyFinished(@Nullable Throwable t) { allocationContext.metrics().runningAllocations.labels(vm.poolLabel()).dec(); if (deleteVmAction != null) { diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java index 1e7ed69067..e77d8d09d4 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java @@ -46,14 +46,6 @@ protected boolean isInjectedError(Error e) { return e instanceof InjectedFailures.TerminateException; } - @Override - protected void notifyExpired() { - } - - @Override - protected void notifyFinished() { - } - @Override protected void onExpired(TransactionHandle tx) { throw new RuntimeException("Unexpected, sessionId: '%s', op: '%s'".formatted(sessionId, id())); diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java index 97378fd0bd..f097e53c79 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java +++ 
b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java @@ -68,7 +68,7 @@ protected boolean isInjectedError(Error e) { } @Override - protected void notifyFinished() { + protected void notifyFinished(@Nullable Throwable t) { if (unmountAction != null) { log().error("{} Failed to mount dynamic disk", logPrefix()); try { diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java index 8294a2aec9..84a8d5c064 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java @@ -59,7 +59,7 @@ public static String description(@Nullable Vm vm, DynamicMount mount) { } @Override - protected void notifyFinished() { + protected void notifyFinished(@Nullable Throwable t) { log().info("{} Finished unmounting volume {}", logPrefix(), dynamicMount.id()); } diff --git a/lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql b/lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql new file mode 100644 index 0000000000..7a5e930a57 --- /dev/null +++ b/lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql @@ -0,0 +1,21 @@ +CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); + +CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); + +CREATE TABLE IF NOT EXISTS task( + id BIGSERIAL NOT NULL, + name TEXT NOT NULL, + entity_id TEXT NOT NULL, + type task_type NOT NULL, + status task_status NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + metadata JSONB NOT NULL, + operation_id TEXT, + worker_id TEXT, + lease_till TIMESTAMP, + PRIMARY KEY (id), + FOREIGN KEY (operation_id) REFERENCES operation(id) +); + +CREATE INDEX IF NOT EXISTS task_status_entity_id_idx ON task(status, entity_id, id); diff --git a/lzy/long-running/pom.xml 
b/lzy/long-running/pom.xml index 7930f9992d..f295c97eb7 100644 --- a/lzy/long-running/pom.xml +++ b/lzy/long-running/pom.xml @@ -45,6 +45,11 @@ junit test + + io.zonky.test + embedded-postgres + test + diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java index 434392b044..15d4e70543 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java @@ -55,6 +55,7 @@ protected final void execute() { } for (var step : steps()) { + beforeStep(); final var stepResult = step.get(); switch (stepResult.code()) { case ALREADY_DONE -> { } @@ -83,7 +84,7 @@ protected final void execute() { } } } catch (Throwable e) { - notifyFinished(); + notifyFinished(e); if (e instanceof Error err && isInjectedError(err)) { log.error("{} Terminated by InjectedFailure exception: {}", logPrefix, e.getMessage()); } else { @@ -98,6 +99,10 @@ protected final void execute() { } } + protected void beforeStep() { + + } + protected Map prepareLogContext() { var ctx = super.prepareLogContext(); ctx.put(LogContextKey.OPERATION_ID, id); @@ -272,7 +277,11 @@ protected void onNotFound(@Nullable TransactionHandle tx) throws SQLException { protected void onCompletedOutside(Operation op, @Nullable TransactionHandle tx) throws SQLException { } - protected void notifyFinished() { + private void notifyFinished() { + notifyFinished(null); + } + + protected void notifyFinished(@Nullable Throwable t) { } protected final void failOperation(Status status, @Nullable TransactionHandle tx) throws SQLException { diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java new file mode 100644 index 0000000000..0183170792 --- /dev/null +++ 
b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java @@ -0,0 +1,100 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.task.dao.TaskDao; +import com.google.common.collect.Queues; +import jakarta.annotation.Nullable; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.time.Duration; +import java.util.List; +import java.util.Queue; + +import static ai.lzy.model.db.DbHelper.withRetries; + +public class DaoTaskQueue implements TaskQueue { + + private static final Logger LOG = LogManager.getLogger(DaoTaskQueue.class); + + private final TaskDao taskDao; + private final int maxQueueSize; + private final Duration leaseTime; + private final String instanceId; + private final Queue queue; + + public DaoTaskQueue(TaskDao taskDao, int maxQueueSize, Duration leaseTime, String instanceId) { + this.taskDao = taskDao; + this.maxQueueSize = maxQueueSize; + this.leaseTime = leaseTime; + this.instanceId = instanceId; + this.queue = Queues.newConcurrentLinkedQueue(); + } + + private void loadNextBatch() { + var toLoad = capacity(); + if (toLoad > 0) { + var tasks = loadPendingTasks(toLoad); + queue.addAll(tasks); + } + } + + private int capacity() { + return maxQueueSize - queue.size(); + } + + private List loadPendingTasks(int toLoad) { + try { + return withRetries(LOG, () -> taskDao.lockPendingBatch(instanceId, leaseTime, toLoad, null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void add(Task task) { + try { + withRetries(LOG, () -> taskDao.insert(task, null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Nullable + @Override + public Task pollNext() { + var next = queue.poll(); + if (next == null) { + loadNextBatch(); + next = queue.poll(); + } + return next; + } + + public List pollRemaining() { + if (!queue.isEmpty()) { + var result = queue.stream().toList(); + queue.clear(); + return result; + } + //don't load to queue, 
just return loaded tasks + return loadPendingTasks(capacity()); + } + + @Override + public void update(long id, Task.Update update) { + try { + withRetries(LOG, () -> taskDao.update(id, update, null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void updateLease(Task task, Duration duration) { + try { + withRetries(LOG, () -> taskDao.updateLease(task.id(), duration, null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java new file mode 100644 index 0000000000..5b7d27cbc7 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java @@ -0,0 +1,74 @@ +package ai.lzy.longrunning.task; + +import jakarta.annotation.Nullable; + +import java.time.Instant; +import java.util.Map; + +public record Task( + long id, + String name, + String entityId, + String type, + Status status, + Instant createdAt, + Instant updatedAt, + Map metadata, + @Nullable + String operationId, + @Nullable + String workerId, + @Nullable + Instant leaseTill +) { + public static Task createPending(String name, String entityId, String type, Map metadata) { + return new Task(-1, name, entityId, type, Status.PENDING, Instant.now(), Instant.now(), + metadata, null, null, null); + } + + public enum Status { + PENDING, + RUNNING, + FAILED, + FINISHED, + } + + public record Update( + @Nullable Status status, + @Nullable Map metadata, + @Nullable String operationId + ) { + public boolean isEmpty() { + return status == null && metadata == null && operationId == null; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Status status; + private Map metadata; + private String operationId; + + public Builder status(Status status) { + this.status = status; + return this; + } + + public Builder metadata(Map metadata) { + 
this.metadata = metadata; + return this; + } + + public Builder operationId(String operationId) { + this.operationId = operationId; + return this; + } + + public Update build() { + return new Update(status, metadata, operationId); + } + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java new file mode 100644 index 0000000000..6c9d97fdd6 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java @@ -0,0 +1,58 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationRunnerBase; +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDao; +import ai.lzy.model.db.Storage; +import jakarta.annotation.Nullable; + +import java.time.Duration; +import java.util.Map; + +public abstract class TaskAwareAction extends OperationRunnerBase { + private final Task task; + private final TaskQueue queue; + private final Duration leaseDuration; + + public TaskAwareAction(Task task, TaskQueue queue, Duration leaseDuration, String opId, String desc, + Storage storage, OperationDao operationsDao, OperationsExecutor executor) + { + super(opId, desc, storage, operationsDao, executor); + this.task = task; + this.queue = queue; + this.leaseDuration = leaseDuration; + } + + @Override + protected Map prepareLogContext() { + var ctx = super.prepareLogContext(); + ctx.put("task_id", String.valueOf(task.id())); + ctx.put("task_type", task.type()); + ctx.put("task_name", task.name()); + ctx.put("task_entity_id", task.entityId()); + return ctx; + } + + protected Task task() { + return task; + } + + @Override + protected void beforeStep() { + super.beforeStep(); + queue.updateLease(task, leaseDuration); + } + + @Override + protected void notifyFinished(@Nullable Throwable t) { + super.notifyFinished(t); + + var builder = Task.Update.builder(); + if (t != null) { + 
builder.status(Task.Status.FAILED); + } else { + builder.status(Task.Status.FINISHED); + } + queue.update(task.id(), builder.build()); + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java new file mode 100644 index 0000000000..f0f07eff7a --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java @@ -0,0 +1,63 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationsExecutor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +public class TaskExecutor { + + private static final Logger LOG = LogManager.getLogger(TaskExecutor.class); + + private final TaskQueue taskQueue; + private final OperationsExecutor operationsExecutor; + private final TaskResolver resolver; + private final ScheduledExecutorService scheduler; + private final Duration initialDelay; + private final Duration executionDelay; + + private volatile boolean started = false; + private volatile boolean disabled = false; + + public TaskExecutor(TaskQueue taskQueue, OperationsExecutor operationsExecutor, TaskResolver resolver, + Duration initialDelay, Duration executionDelay) + { + this.taskQueue = taskQueue; + this.operationsExecutor = operationsExecutor; + this.resolver = resolver; + this.initialDelay = initialDelay; + this.executionDelay = executionDelay; + this.scheduler = Executors.newSingleThreadScheduledExecutor(); + + } + + //todo support worker failures - retry locked tasks after restart + public void start() { + if (started) { + throw new IllegalStateException("Task executor has already started!"); + } + started = true; + scheduler.scheduleWithFixedDelay(() -> { + try { + for (Task task : taskQueue.pollRemaining()) { + if (disabled) { 
+ return; + } + var resolvedAction = resolver.resolve(task); + operationsExecutor.startNew(resolvedAction); + } + } catch (Exception e) { + LOG.error("Got exception while scheduling task", e); + } + }, initialDelay.toMillis(), executionDelay.toMillis(), TimeUnit.MILLISECONDS); + } + + public void shutdown() { + disabled = true; + scheduler.shutdown(); + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java new file mode 100644 index 0000000000..f76e58d592 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java @@ -0,0 +1,19 @@ +package ai.lzy.longrunning.task; + +import jakarta.annotation.Nullable; + +import java.time.Duration; +import java.util.List; + +public interface TaskQueue { + void add(Task task); + + @Nullable + Task pollNext(); + + List pollRemaining(); + + void update(long id, Task.Update update); + + void updateLease(Task task, Duration duration); +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskResolver.java new file mode 100644 index 0000000000..8193886c51 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskResolver.java @@ -0,0 +1,5 @@ +package ai.lzy.longrunning.task; + +public interface TaskResolver { + TaskAwareAction resolve(Task task); +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java new file mode 100644 index 0000000000..31a92ee5d6 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java @@ -0,0 +1,29 @@ +package ai.lzy.longrunning.task.dao; + +import ai.lzy.longrunning.task.Task; +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; + +import java.sql.SQLException; +import java.time.Duration; +import 
java.util.List; + +public interface TaskDao { + + @Nullable + Task get(long id, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + Task update(long id, Task.Update update, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + Task updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + Task insert(Task task, @Nullable TransactionHandle tx) throws SQLException; + + void delete(long id, @Nullable TransactionHandle tx) throws SQLException; + + List lockPendingBatch(String ownerId, Duration leaseTime, int batchSize, @Nullable TransactionHandle tx) + throws SQLException; +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDaoImpl.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDaoImpl.java new file mode 100644 index 0000000000..53065a517c --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDaoImpl.java @@ -0,0 +1,227 @@ +package ai.lzy.longrunning.task.dao; + +import ai.lzy.longrunning.task.Task; +import ai.lzy.model.db.DbOperation; +import ai.lzy.model.db.Storage; +import ai.lzy.model.db.TransactionHandle; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.annotation.Nullable; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class TaskDaoImpl implements TaskDao { + + private static final String FIELDS = "id, name, entity_id, type, status, created_at, updated_at, metadata," + + " operation_id, worker_id, lease_till"; + public static final String SELECT_QUERY = "SELECT %s FROM task WHERE id = ?".formatted(FIELDS); + public static final String INSERT_QUERY = """ + INSERT INTO task (name, entity_id, type, status, 
created_at, updated_at, metadata, operation_id) + VALUES (?, ?, cast(? as task_type), cast(? as task_status), now(), now(), cast(? as jsonb), ?) + RETURNING %s; + """.formatted(FIELDS); + + //in first nested request we gather all tasks that either locked or free. + //in second nested request we filter result of previous request to get only free tasks + // and select only specific amount. + private static final String LOCK_PENDING_BATCH_QUERY = """ + UPDATE task + SET status = 'RUNNING', worker_id = ?, updated_at = now(), lease_till = now() + cast(? as interval) + WHERE id IN ( + SELECT id + FROM task + WHERE id IN ( + SELECT DISTINCT ON (entity_id) id + FROM task + WHERE status IN ('PENDING', 'RUNNING') + ORDER BY entity_id, id + ) AND status = 'PENDING' + LIMIT ? + ) + RETURNING %s; + """.formatted(FIELDS); + public static final String DELETE_QUERY = "DELETE FROM task WHERE id = ?"; + public static final String UPDATE_LEASE_QUERY = """ + UPDATE task + SET lease_till = now() + cast(? as interval) + WHERE id = ? 
+ RETURNING %s + """.formatted(FIELDS); + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() { }; + + + private final Storage storage; + private final ObjectMapper objectMapper; + + public TaskDaoImpl(Storage storage, ObjectMapper objectMapper) { + this.storage = storage; + this.objectMapper = objectMapper; + } + + @Nullable + @Override + public Task get(long id, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(SELECT_QUERY)) { + ps.setLong(1, id); + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) { + return readTask(rs); + } + return null; + } + } + }); + } + + @Override + @Nullable + public Task update(long id, Task.Update update, @Nullable TransactionHandle tx) throws SQLException { + if (update.isEmpty()) { + return null; + } + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(prepareUpdateQuery(update))) { + prepareUpdateParameters(ps, id, update); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + @Override + @Nullable + public Task updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(UPDATE_LEASE_QUERY)) { + ps.setString(1, duration.toString()); + ps.setLong(2, id); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } + }); + } + + @Nullable + @Override + public Task insert(Task task, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(INSERT_QUERY)) { + var i = 0; + ps.setString(++i, task.name()); + ps.setString(++i, task.entityId()); + ps.setString(++i, task.type()); + 
ps.setString(++i, task.status().name()); + ps.setString(++i, objectMapper.writeValueAsString(task.metadata())); + ps.setString(++i, task.operationId()); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + public void delete(long id, @Nullable TransactionHandle tx) throws SQLException { + DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(DELETE_QUERY)) { + ps.setLong(1, id); + ps.executeUpdate(); + } + }); + } + + @Override + public List lockPendingBatch(String ownerId, Duration leaseTime, int batchSize, + @Nullable TransactionHandle tx) throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(LOCK_PENDING_BATCH_QUERY)) { + var i = 0; + ps.setString(++i, ownerId); + ps.setString(++i, leaseTime.toString()); + ps.setInt(++i, batchSize); + try (var rs = ps.executeQuery()) { + var result = new ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + + private static String prepareUpdateQuery(Task.Update update) { + var sb = new StringBuilder("UPDATE task SET "); + if (update.status() != null) { + sb.append("status = cast(? as task_status), "); + } + if (update.metadata() != null) { + sb.append("metadata = cast(? as jsonb), "); + } + if (update.operationId() != null) { + sb.append("operation_id = ?, "); + } + sb.setLength(sb.length() - 2); + sb.append(" WHERE id = ? 
RETURNING ").append(FIELDS); + return sb.toString(); + } + + private void prepareUpdateParameters(PreparedStatement ps, long id, Task.Update update) + throws JsonProcessingException, SQLException + { + var i = 0; + if (update.status() != null) { + ps.setString(++i, update.status().name()); + } + if (update.metadata() != null) { + ps.setString(++i, objectMapper.writeValueAsString(update.metadata())); + } + if (update.operationId() != null) { + ps.setString(++i, update.operationId()); + } + ps.setLong(++i, id); + } + + private Task readTask(ResultSet rs) throws SQLException { + try { + var id = rs.getLong("id"); + var name = rs.getString("name"); + var entityId = rs.getString("entity_id"); + var type = rs.getString("type"); + var status = Task.Status.valueOf(rs.getString("status")); + var createdAt = rs.getTimestamp("created_at").toInstant(); + var updatedAt = rs.getTimestamp("updated_at").toInstant(); + var metadata = objectMapper.readValue(rs.getString("metadata"), MAP_TYPE_REFERENCE); + var operationId = rs.getString("operation_id"); + var workerId = rs.getString("worker_id"); + var leaseTillTs = rs.getTimestamp("lease_till"); + var leaseTill = leaseTillTs != null ? 
leaseTillTs.toInstant() : null; + return new Task(id, name, entityId, type, status, createdAt, updatedAt, metadata, operationId, workerId, + leaseTill); + } catch (JsonProcessingException e) { + throw new RuntimeException("Cannot read metadata for task object", e); + } + + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java new file mode 100644 index 0000000000..a4d200fe45 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java @@ -0,0 +1,139 @@ +package ai.lzy.longrunning; + +import ai.lzy.longrunning.task.Task; +import ai.lzy.longrunning.task.dao.TaskDaoImpl; +import ai.lzy.model.db.StorageImpl; +import ai.lzy.model.db.test.DatabaseTestUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; +import io.zonky.test.db.postgres.junit.PreparedDbRule; +import org.junit.*; + +import java.time.Duration; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class TaskDaoImplTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); + + private StorageImpl storage; + private TaskDaoImpl taskDao; + + + @Before + public void setup() { + storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), + "classpath:db/migrations") {}; + taskDao = new TaskDaoImpl(storage, new ObjectMapper()); + } + + @After + public void teardown() { + DatabaseTestUtils.cleanup(storage); + storage.close(); + } + + @Test + public void create() throws Exception { + var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var fetched = taskDao.get(task.id(), null); + assertEquals(task, fetched); + } + + @Test + public void multiCreate() throws Exception { + 
var task1 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task2 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task3 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + Assert.assertTrue(task1.id() < task2.id()); + Assert.assertTrue(task2.id() < task3.id()); + } + + @Test + public void update() throws Exception { + var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var updated = taskDao.update(task.id(), Task.Update.builder().status(Task.Status.RUNNING).build(), null); + assertEquals(Task.Status.RUNNING, updated.status()); + updated = taskDao.update(task.id(), Task.Update.builder().operationId("42").build(), null); + assertEquals("42", updated.operationId()); + updated = taskDao.update(task.id(), Task.Update.builder().metadata(Map.of("qux", "quux")).build(), null); + assertEquals(Map.of("qux", "quux"), updated.metadata()); + updated = taskDao.update(task.id(), Task.Update.builder() + .status(Task.Status.FINISHED) + .operationId("0") + .metadata(Map.of()) + .build(), null); + assertEquals(Task.Status.FINISHED, updated.status()); + assertEquals("0", updated.operationId()); + assertEquals(Map.of(), updated.metadata()); + } + + @Test + public void delete() throws Exception { + var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + taskDao.delete(task.id(), null); + } + + @Test + public void updateLease() throws Exception { + var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var updateLease = taskDao.updateLease(task.id(), Duration.ofMinutes(5), null); + var between = Duration.between(task.createdAt(), updateLease.leaseTill()); + //leaseTill should be around 5 minutes from now + assertTrue(between.compareTo(Duration.ofMinutes(5)) >= 0); + } + + @Test + public void getUnknown() throws Exception { + 
var task = taskDao.get(42, null); + Assert.assertNull(task); + } + + @Test + public void updateUnknown() throws Exception { + var updated = taskDao.update(42, Task.Update.builder().status(Task.Status.RUNNING).build(), null); + Assert.assertNull(updated); + } + + @Test + public void deleteUnknown() throws Exception { + taskDao.delete(42, null); + } + + @Test + public void updateLeaseUnknown() throws Exception { + var updated = taskDao.updateLease(42, Duration.ofMinutes(5), null); + Assert.assertNull(updated); + } + + @Test + public void lockPendingBatch() throws Exception { + var task1 = taskDao.insert(Task.createPending("task1", "1", "MOUNT", Map.of()), null); + var task2 = taskDao.insert(Task.createPending("task2", "2", "MOUNT", Map.of()), null); + var task3 = taskDao.insert(Task.createPending("task3", "3", "MOUNT", Map.of()), null); + var task4 = taskDao.insert(Task.createPending("task4", "4", "MOUNT", Map.of()), null); + var task5 = taskDao.insert(Task.createPending("task5", "1", "MOUNT", Map.of()), null); + var task6 = taskDao.insert(Task.createPending("task6", "2", "MOUNT", Map.of()), null); + var task7 = taskDao.insert(Task.createPending("task7", "3", "MOUNT", Map.of()), null); + var task8 = taskDao.insert(Task.createPending("task8", "4", "MOUNT", Map.of()), null); + + task2 = taskDao.update(task2.id(), statusUpdate(Task.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(Task.Status.FINISHED), null); + task4 = taskDao.update(task4.id(), statusUpdate(Task.Status.FAILED), null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + var lockedTaskIds = lockedTasks.stream().map(Task::id).collect(Collectors.toSet()); + assertEquals(Set.of(task1.id(), task7.id(), task8.id()), lockedTaskIds); + } + + private Task.Update statusUpdate(Task.Status status) { + return Task.Update.builder().status(status).build(); + } + +} diff --git a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql 
b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql new file mode 100644 index 0000000000..7c14e8c060 --- /dev/null +++ b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql @@ -0,0 +1,18 @@ +CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); + +CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); + +CREATE TABLE IF NOT EXISTS task( + id BIGSERIAL NOT NULL, + name TEXT NOT NULL, + entity_id TEXT NOT NULL, + type task_type NOT NULL, + status task_status NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + metadata JSONB NOT NULL, + operation_id TEXT, + worker_id TEXT, + lease_till TIMESTAMP, + PRIMARY KEY (id) +); \ No newline at end of file diff --git a/lzy/long-running/src/test/resources/log4j2.yaml b/lzy/long-running/src/test/resources/log4j2.yaml new file mode 100644 index 0000000000..c1554d579e --- /dev/null +++ b/lzy/long-running/src/test/resources/log4j2.yaml @@ -0,0 +1,28 @@ +Configuration: + status: warn + + Appenders: + Console: + name: Console + target: SYSTEM_OUT + PatternLayout: + Pattern: "%d{yyyy-MM-dd HH:mm:ss.SSS}{UTC} [%t] %-5level %logger{36} %notEmpty{[rid=%X{rid}] }- %msg%n" + + Loggers: + Root: + level: warn + AppenderRef: + ref: Console + + Logger: + - name: UserEventLogs + level: error + + - name: "ai.lzy.model.utils.FreePortFinder" + level: error + + - name: io.zonky.test + level: warn + + - name: org.flywaydb + level: warn diff --git a/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java b/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java index 8cb35a7372..6bd24d3d62 100644 --- a/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java +++ b/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java @@ -1,5 +1,6 @@ package ai.lzy.model.db.test; +import ai.lzy.model.db.DatabaseConfiguration; import ai.lzy.model.db.Storage; import java.sql.SQLException; @@ -29,6 +30,28 @@ public static HashMap 
preparePostgresConfig(String app, Object c } } + public static DatabaseConfiguration preparePostgresConfig(Object ci) { + /* io.zonky.test.db.postgres.embedded.ConnectionInfo ci */ + try { + var user = getFieldValue(ci, "user", String.class); + assert "postgres".equals(user); + + var port = getFieldValue(ci, "port", Integer.class); + var dbName = getFieldValue(ci, "dbName", String.class); + + var result = new DatabaseConfiguration(); + result.setEnabled(true); + result.setUrl("jdbc:postgresql://localhost:%d/%s".formatted(port, dbName)); + result.setUsername("postgres"); + result.setPassword(""); + result.setMinPoolSize(1); + result.setMaxPoolSize(10); + return result; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("checkstyle:Indentation") public static HashMap prepareLocalhostConfig(String app) { return new HashMap<>() {{ From d7a2b24c0e7bdbde7f9e970540094cd2ea3558a7 Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 29 May 2023 18:05:44 +0300 Subject: [PATCH 2/9] dao task queue test --- lzy/long-running/pom.xml | 5 + .../ai/lzy/longrunning/task/DaoTaskQueue.java | 30 ++-- .../lzy/longrunning/task/TaskAwareAction.java | 6 +- .../ai/lzy/longrunning/task/TaskQueue.java | 8 +- .../ai/lzy/longrunning/DaoTaskQueueTest.java | 129 ++++++++++++++++++ .../ai/lzy/longrunning/TaskDaoImplTest.java | 20 +++ 6 files changed, 182 insertions(+), 16 deletions(-) create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java diff --git a/lzy/long-running/pom.xml b/lzy/long-running/pom.xml index f295c97eb7..634dbb4d59 100644 --- a/lzy/long-running/pom.xml +++ b/lzy/long-running/pom.xml @@ -50,6 +50,11 @@ embedded-postgres test + + org.mockito + mockito-core + test + diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java index 0183170792..dac5a11a2c 100644 --- 
a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java @@ -18,14 +18,14 @@ public class DaoTaskQueue implements TaskQueue { private final TaskDao taskDao; private final int maxQueueSize; - private final Duration leaseTime; + private final Duration initialLeaseTime; private final String instanceId; private final Queue queue; - public DaoTaskQueue(TaskDao taskDao, int maxQueueSize, Duration leaseTime, String instanceId) { + public DaoTaskQueue(TaskDao taskDao, int maxQueueSize, Duration initialLeaseTime, String instanceId) { this.taskDao = taskDao; this.maxQueueSize = maxQueueSize; - this.leaseTime = leaseTime; + this.initialLeaseTime = initialLeaseTime; this.instanceId = instanceId; this.queue = Queues.newConcurrentLinkedQueue(); } @@ -44,16 +44,16 @@ private int capacity() { private List loadPendingTasks(int toLoad) { try { - return withRetries(LOG, () -> taskDao.lockPendingBatch(instanceId, leaseTime, toLoad, null)); + return withRetries(LOG, () -> taskDao.lockPendingBatch(instanceId, initialLeaseTime, toLoad, null)); } catch (Exception e) { throw new RuntimeException(e); } } @Override - public void add(Task task) { + public Task add(Task task) { try { - withRetries(LOG, () -> taskDao.insert(task, null)); + return withRetries(LOG, () -> taskDao.insert(task, null)); } catch (Exception e) { throw new RuntimeException(); } @@ -70,6 +70,7 @@ public Task pollNext() { return next; } + @Override public List pollRemaining() { if (!queue.isEmpty()) { var result = queue.stream().toList(); @@ -81,18 +82,27 @@ public List pollRemaining() { } @Override - public void update(long id, Task.Update update) { + public Task update(long id, Task.Update update) { + try { + return withRetries(LOG, () -> taskDao.update(id, update, null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Task updateLease(long taskId, Duration duration) { try { - 
withRetries(LOG, () -> taskDao.update(id, update, null)); + return withRetries(LOG, () -> taskDao.updateLease(taskId, duration, null)); } catch (Exception e) { throw new RuntimeException(e); } } @Override - public void updateLease(Task task, Duration duration) { + public void delete(long id) { try { - withRetries(LOG, () -> taskDao.updateLease(task.id(), duration, null)); + withRetries(LOG, () -> taskDao.delete(id, null)); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java index 6c9d97fdd6..e7800c6311 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java @@ -10,9 +10,9 @@ import java.util.Map; public abstract class TaskAwareAction extends OperationRunnerBase { - private final Task task; private final TaskQueue queue; private final Duration leaseDuration; + private Task task; public TaskAwareAction(Task task, TaskQueue queue, Duration leaseDuration, String opId, String desc, Storage storage, OperationDao operationsDao, OperationsExecutor executor) @@ -40,7 +40,7 @@ protected Task task() { @Override protected void beforeStep() { super.beforeStep(); - queue.updateLease(task, leaseDuration); + task = queue.updateLease(task.id(), leaseDuration); } @Override @@ -53,6 +53,6 @@ protected void notifyFinished(@Nullable Throwable t) { } else { builder.status(Task.Status.FINISHED); } - queue.update(task.id(), builder.build()); + task = queue.update(task.id(), builder.build()); } } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java index f76e58d592..a6c5feba09 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java +++ 
b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java @@ -6,14 +6,16 @@ import java.util.List; public interface TaskQueue { - void add(Task task); + Task add(Task task); @Nullable Task pollNext(); List pollRemaining(); - void update(long id, Task.Update update); + Task update(long id, Task.Update update); - void updateLease(Task task, Duration duration); + Task updateLease(long taskId, Duration duration); + + void delete(long id); } diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java new file mode 100644 index 0000000000..b1d94af886 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java @@ -0,0 +1,129 @@ +package ai.lzy.longrunning; + +import ai.lzy.longrunning.task.DaoTaskQueue; +import ai.lzy.longrunning.task.Task; +import ai.lzy.longrunning.task.dao.TaskDao; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.Map; + +public class DaoTaskQueueTest { + + private static final int MAX_QUEUE_SIZE = 10; + private static final Duration INITIAL_LEASE_TIME = Duration.of(5, ChronoUnit.MINUTES); + private static final String INSTANCE_ID = "worker"; + + private DaoTaskQueue taskQueue; + private TaskDao taskDaoMock; + + @Before + public void setup() { + taskDaoMock = Mockito.mock(TaskDao.class); + taskQueue = new DaoTaskQueue(taskDaoMock, MAX_QUEUE_SIZE, INITIAL_LEASE_TIME, + INSTANCE_ID); + } + + @Test + public void add() throws Exception { + Mockito.when(taskDaoMock.insert(Mockito.any(), Mockito.any())).thenReturn(null); + var task = Task.createPending("task", "1", "MOUNT", Map.of()); + taskQueue.add(task); + Mockito.verify(taskDaoMock, Mockito.only()).insert(task, null); + } + + @Test + public void update() throws Exception { + 
Mockito.when(taskDaoMock.update(Mockito.anyLong(), Mockito.any(), Mockito.any())).thenReturn(null); + var task = Task.createPending("task", "1", "MOUNT", Map.of()); + var update = Task.Update.builder().status(Task.Status.RUNNING).build(); + taskQueue.update(task.id(), update); + Mockito.verify(taskDaoMock, Mockito.only()).update(task.id(), update, null); + } + + @Test + public void updateLease() throws Exception { + Mockito.when(taskDaoMock.update(Mockito.anyLong(), Mockito.any(), Mockito.any())).thenReturn(null); + var task = Task.createPending("task", "1", "MOUNT", Map.of()); + var duration = Duration.ofSeconds(42); + taskQueue.updateLease(task.id(), duration); + Mockito.verify(taskDaoMock, Mockito.only()).updateLease(task.id(), duration, null); + } + + @Test + public void pollNext() throws Exception { + var tasks = List.of( + Task.createPending("task1", "1", "MOUNT", Map.of()), + Task.createPending("task2", "2", "MOUNT", Map.of()), + Task.createPending("task3", "3", "MOUNT", Map.of()) + ); + Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) + .thenReturn(tasks); + var polledTask1 = taskQueue.pollNext(); + Mockito.verify(taskDaoMock, Mockito.only()) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + var polledTask2 = taskQueue.pollNext(); + var polledTask3 = taskQueue.pollNext(); + Mockito.verify(taskDaoMock, Mockito.only()) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + Assert.assertEquals(tasks.get(0), polledTask1); + Assert.assertEquals(tasks.get(1), polledTask2); + Assert.assertEquals(tasks.get(2), polledTask3); + + var anotherBatch = List.of(Task.createPending("task4", "4", "MOUNT", Map.of())); + Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) + .thenReturn(anotherBatch); + var polledTask4 = taskQueue.pollNext(); + Mockito.verify(taskDaoMock, Mockito.times(2)) + 
.lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + Assert.assertEquals(anotherBatch.get(0), polledTask4); + + Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) + .thenReturn(List.of()); + var polledTask5 = taskQueue.pollNext(); + Mockito.verify(taskDaoMock, Mockito.times(3)) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + Assert.assertNull(polledTask5); + } + + @Test + public void delete() throws Exception { + Mockito.doNothing().when(taskDaoMock).delete(Mockito.anyLong(), Mockito.any()); + var task = Task.createPending("task", "1", "MOUNT", Map.of()); + taskQueue.delete(task.id()); + Mockito.verify(taskDaoMock, Mockito.only()).delete(task.id(), null); + } + + @Test + public void pollRemaining() throws Exception { + var tasks = List.of( + Task.createPending("task1", "1", "MOUNT", Map.of()), + Task.createPending("task2", "2", "MOUNT", Map.of()), + Task.createPending("task3", "3", "MOUNT", Map.of()) + ); + Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) + .thenReturn(tasks); + var task = taskQueue.pollNext(); + Assert.assertEquals(tasks.get(0), task); + Mockito.verify(taskDaoMock, Mockito.only()) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + + var remainingTasks = taskQueue.pollRemaining(); + Mockito.verify(taskDaoMock, Mockito.only()) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + Assert.assertEquals(tasks.subList(1, 3), remainingTasks); + + Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) + .thenReturn(List.of()); + var newTasks = taskQueue.pollRemaining(); + Mockito.verify(taskDaoMock, Mockito.times(2)) + .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); + Assert.assertEquals(List.of(), newTasks); + } + +} diff --git 
a/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java index a4d200fe45..6d8f9d09ea 100644 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java @@ -132,6 +132,26 @@ public void lockPendingBatch() throws Exception { assertEquals(Set.of(task1.id(), task7.id(), task8.id()), lockedTaskIds); } + @Test + public void lockPendingBatchWithAllRunning() throws Exception { + var task1 = taskDao.insert(Task.createPending("task1", "1", "MOUNT", Map.of()), null); + var task2 = taskDao.insert(Task.createPending("task2", "2", "MOUNT", Map.of()), null); + var task3 = taskDao.insert(Task.createPending("task3", "3", "MOUNT", Map.of()), null); + var task4 = taskDao.insert(Task.createPending("task4", "4", "MOUNT", Map.of()), null); + var task5 = taskDao.insert(Task.createPending("task5", "1", "MOUNT", Map.of()), null); + var task6 = taskDao.insert(Task.createPending("task6", "2", "MOUNT", Map.of()), null); + var task7 = taskDao.insert(Task.createPending("task7", "3", "MOUNT", Map.of()), null); + var task8 = taskDao.insert(Task.createPending("task8", "4", "MOUNT", Map.of()), null); + + task1 = taskDao.update(task1.id(), statusUpdate(Task.Status.RUNNING), null); + task2 = taskDao.update(task2.id(), statusUpdate(Task.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(Task.Status.RUNNING), null); + task4 = taskDao.update(task4.id(), statusUpdate(Task.Status.RUNNING), null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + assertTrue(lockedTasks.isEmpty()); + } + private Task.Update statusUpdate(Task.Status status) { return Task.Update.builder().status(status).build(); } From 0afbb719fce7bc62f89e2ba87c36ff85b3c27d2b Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 00:18:29 +0300 Subject: 
[PATCH 3/9] new naming --- .../{V8__task.sql => V9__operation_task.sql} | 6 +- .../ai/lzy/longrunning/task/DaoTaskQueue.java | 110 ---------- .../DispatchingOperationTaskResolver.java | 50 +++++ .../longrunning/task/OpTaskAwareAction.java | 76 +++++++ .../task/{Task.java => OperationTask.java} | 15 +- .../task/OperationTaskExecutor.java | 193 ++++++++++++++++++ .../task/OperationTaskResolver.java | 36 ++++ .../lzy/longrunning/task/ResolverUtils.java | 19 ++ .../lzy/longrunning/task/TaskAwareAction.java | 58 ------ .../ai/lzy/longrunning/task/TaskExecutor.java | 63 ------ .../longrunning/task/TaskMetricsProvider.java | 9 + .../ai/lzy/longrunning/task/TaskQueue.java | 21 -- .../ai/lzy/longrunning/task/TaskResolver.java | 5 - .../task/TypedOperationTaskResolver.java | 5 + .../task/dao/OperationTaskDao.java | 36 ++++ ...DaoImpl.java => OperationTaskDaoImpl.java} | 140 ++++++++++--- .../ai/lzy/longrunning/task/dao/TaskDao.java | 29 --- .../ai/lzy/longrunning/DaoTaskQueueTest.java | 129 ------------ .../longrunning/OperationTaskDaoImplTest.java | 159 +++++++++++++++ .../ai/lzy/longrunning/TaskDaoImplTest.java | 159 --------------- .../test/resources/db/migrations/V1__task.sql | 2 +- 21 files changed, 702 insertions(+), 618 deletions(-) rename lzy/allocator/src/main/resources/db/allocator/migrations/{V8__task.sql => V9__operation_task.sql} (79%) delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java rename lzy/long-running/src/main/java/ai/lzy/longrunning/task/{Task.java => OperationTask.java} (74%) create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java create mode 100644 
lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskQueue.java delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskResolver.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/TypedOperationTaskResolver.java create mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java rename lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/{TaskDaoImpl.java => OperationTaskDaoImpl.java} (57%) delete mode 100644 lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java delete mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java delete mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java diff --git a/lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql b/lzy/allocator/src/main/resources/db/allocator/migrations/V9__operation_task.sql similarity index 79% rename from lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql rename to lzy/allocator/src/main/resources/db/allocator/migrations/V9__operation_task.sql index 7a5e930a57..8149674abc 100644 --- a/lzy/allocator/src/main/resources/db/allocator/migrations/V8__task.sql +++ b/lzy/allocator/src/main/resources/db/allocator/migrations/V9__operation_task.sql @@ -1,8 +1,8 @@ -CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); +CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED', 'STALE'); CREATE TYPE task_type 
AS ENUM ('UNMOUNT', 'MOUNT'); -CREATE TABLE IF NOT EXISTS task( +CREATE TABLE IF NOT EXISTS operation_task( id BIGSERIAL NOT NULL, name TEXT NOT NULL, entity_id TEXT NOT NULL, @@ -18,4 +18,4 @@ CREATE TABLE IF NOT EXISTS task( FOREIGN KEY (operation_id) REFERENCES operation(id) ); -CREATE INDEX IF NOT EXISTS task_status_entity_id_idx ON task(status, entity_id, id); +CREATE INDEX IF NOT EXISTS task_status_entity_id_idx ON operation_task(status, entity_id, id); diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java deleted file mode 100644 index dac5a11a2c..0000000000 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DaoTaskQueue.java +++ /dev/null @@ -1,110 +0,0 @@ -package ai.lzy.longrunning.task; - -import ai.lzy.longrunning.task.dao.TaskDao; -import com.google.common.collect.Queues; -import jakarta.annotation.Nullable; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import java.time.Duration; -import java.util.List; -import java.util.Queue; - -import static ai.lzy.model.db.DbHelper.withRetries; - -public class DaoTaskQueue implements TaskQueue { - - private static final Logger LOG = LogManager.getLogger(TaskExecutor.class); - - private final TaskDao taskDao; - private final int maxQueueSize; - private final Duration initialLeaseTime; - private final String instanceId; - private final Queue queue; - - public DaoTaskQueue(TaskDao taskDao, int maxQueueSize, Duration initialLeaseTime, String instanceId) { - this.taskDao = taskDao; - this.maxQueueSize = maxQueueSize; - this.initialLeaseTime = initialLeaseTime; - this.instanceId = instanceId; - this.queue = Queues.newConcurrentLinkedQueue(); - } - - private void loadNextBatch() { - var toLoad = capacity(); - if (toLoad > 0) { - var tasks = loadPendingTasks(toLoad); - queue.addAll(tasks); - } - } - - private int capacity() { - return maxQueueSize - queue.size(); - 
} - - private List loadPendingTasks(int toLoad) { - try { - return withRetries(LOG, () -> taskDao.lockPendingBatch(instanceId, initialLeaseTime, toLoad, null)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Task add(Task task) { - try { - return withRetries(LOG, () -> taskDao.insert(task, null)); - } catch (Exception e) { - throw new RuntimeException(); - } - } - - @Nullable - @Override - public Task pollNext() { - var next = queue.poll(); - if (next == null) { - loadNextBatch(); - next = queue.poll(); - } - return next; - } - - @Override - public List pollRemaining() { - if (!queue.isEmpty()) { - var result = queue.stream().toList(); - queue.clear(); - return result; - } - //don't load to queue, just return loaded tasks - return loadPendingTasks(capacity()); - } - - @Override - public Task update(long id, Task.Update update) { - try { - return withRetries(LOG, () -> taskDao.update(id, update, null)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Task updateLease(long taskId, Duration duration) { - try { - return withRetries(LOG, () -> taskDao.updateLease(taskId, duration, null)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void delete(long id) { - try { - withRetries(LOG, () -> taskDao.delete(id, null)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } -} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java new file mode 100644 index 0000000000..83c1d83621 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java @@ -0,0 +1,50 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; +import org.apache.logging.log4j.LogManager; +import 
org.apache.logging.log4j.Logger; + +import java.sql.SQLException; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class DispatchingOperationTaskResolver implements OperationTaskResolver { + + private static final Logger LOG = LogManager.getLogger(DispatchingOperationTaskResolver.class); + private final Map resolvers; + + public DispatchingOperationTaskResolver(List resolvers) { + this.resolvers = generateResolversMap(resolvers); + } + + private static Map generateResolversMap( + List resolvers) + { + var types = new HashSet(); + resolvers.forEach(r -> { + if (!types.add(r.type())) { + throw new IllegalStateException("Duplicate resolver for type " + r.type()); + } + }); + return resolvers.stream() + .collect(Collectors.toMap(TypedOperationTaskResolver::type, r -> r)); + } + + @Override + public Result resolve(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { + var resolver = resolvers.get(operationTask.type()); + if (resolver == null) { + LOG.error("No resolver for task type {}. 
Task: {}", operationTask.type(), operationTask); + return Result.UNKNOWN_TASK; + } + try { + return resolver.resolve(operationTask, tx); + } catch (Exception e) { + LOG.error("Error while resolving task {}", operationTask.id(), e); + return Result.resolveError(e); + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java new file mode 100644 index 0000000000..6e262e6f1d --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java @@ -0,0 +1,76 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationRunnerBase; +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDao; +import ai.lzy.longrunning.task.dao.OperationTaskDao; +import ai.lzy.model.db.Storage; +import jakarta.annotation.Nullable; + +import java.time.Duration; +import java.util.Map; + +import static ai.lzy.model.db.DbHelper.withRetries; + +public abstract class OpTaskAwareAction extends OperationRunnerBase { + private final OperationTaskDao operationTaskDao; + private final Duration leaseDuration; + private OperationTask operationTask; + + public OpTaskAwareAction(OperationTask operationTask, OperationTaskDao operationTaskDao, Duration leaseDuration, + String opId, String desc, Storage storage, OperationDao operationsDao, + OperationsExecutor executor) + { + super(opId, desc, storage, operationsDao, executor); + this.operationTask = operationTask; + this.operationTaskDao = operationTaskDao; + this.leaseDuration = leaseDuration; + } + + @Override + protected Map prepareLogContext() { + var ctx = super.prepareLogContext(); + ctx.put("task_id", String.valueOf(operationTask.id())); + ctx.put("task_type", operationTask.type()); + ctx.put("task_name", operationTask.name()); + ctx.put("task_entity_id", operationTask.entityId()); + return ctx; + } + + protected OperationTask task() { + return 
operationTask; + } + + public void setTask(OperationTask operationTask) { + this.operationTask = operationTask; + } + + @Override + protected void beforeStep() { + super.beforeStep(); + try { + operationTask = withRetries(log(), () -> operationTaskDao.updateLease(operationTask.id(), leaseDuration, + null)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + protected void notifyFinished(@Nullable Throwable t) { + super.notifyFinished(t); + + var builder = OperationTask.Update.builder(); + if (t != null) { + builder.status(OperationTask.Status.FAILED); + } else { + builder.status(OperationTask.Status.FINISHED); + } + try { + operationTask = withRetries(log(), () -> operationTaskDao.update(operationTask.id(), builder.build(), + null)); + } catch (Exception e) { + log().error("{} Couldn't finish operation task {}", logPrefix(), task().id()); + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java similarity index 74% rename from lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java rename to lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java index 5b7d27cbc7..8fc130a2d3 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/Task.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java @@ -5,7 +5,7 @@ import java.time.Instant; import java.util.Map; -public record Task( +public record OperationTask( long id, String name, String entityId, @@ -21,16 +21,17 @@ public record Task( @Nullable Instant leaseTill ) { - public static Task createPending(String name, String entityId, String type, Map metadata) { - return new Task(-1, name, entityId, type, Status.PENDING, Instant.now(), Instant.now(), + public static OperationTask createPending(String name, String entityId, String type, Map metadata) { + return new OperationTask(-1, name, entityId, type, Status.PENDING, 
Instant.now(), Instant.now(), metadata, null, null, null); } public enum Status { - PENDING, - RUNNING, - FAILED, - FINISHED, + PENDING, //just created and waiting to be executed + RUNNING, //acquired by one of the instance of app and executing + FINISHED, //successful finish + FAILED, //unrecoverable failure + STALE, //executed too late and not applicable anymore } public record Update( diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java new file mode 100644 index 0000000000..4ab0a53943 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java @@ -0,0 +1,193 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.task.dao.OperationTaskDao; +import ai.lzy.model.db.DbHelper; +import ai.lzy.model.db.Storage; +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +public class OperationTaskExecutor { + + private static final Logger LOG = LogManager.getLogger(OperationTaskExecutor.class); + + private final OperationTaskDao opTaskDao; + private final OperationsExecutor operationsExecutor; + private final OperationTaskResolver resolver; + private final ScheduledExecutorService scheduler; + private final Duration initialDelay; + private final Duration executionDelay; + private final Storage storage; + private final TaskMetricsProvider metricsProvider; + private final String instanceId; + private final Duration leaseDuration; + private final int batchSize; + + private 
volatile boolean started = false; + private volatile boolean disabled = false; + + public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor, + OperationTaskResolver resolver, Duration initialDelay, Duration executionDelay, + Storage storage, TaskMetricsProvider metricsProvider, String instanceId, + Duration leaseDuration, int batchSize) + { + this.opTaskDao = opTaskDao; + this.operationsExecutor = operationsExecutor; + this.resolver = resolver; + this.initialDelay = initialDelay; + this.executionDelay = executionDelay; + this.storage = storage; + this.leaseDuration = leaseDuration; + this.batchSize = batchSize; + this.scheduler = Executors.newSingleThreadScheduledExecutor(); //it's important to have only one thread + this.metricsProvider = metricsProvider; + this.instanceId = instanceId; + } + + public void start() { + if (started) { + throw new IllegalStateException("Task executor has already started!"); + } + started = true; + restoreTasks(); + startMailLoop(); + } + + //todo backpressure - to not start new tasks if there are too many of them + private ScheduledFuture startMailLoop() { + return scheduler.scheduleWithFixedDelay(() -> { + try { + var actions = new ArrayList(); + DbHelper.withRetries(LOG, () -> { + try (var tx = TransactionHandle.create(storage)) { + for (OperationTask operationTask : opTaskDao.lockPendingBatch(instanceId, leaseDuration, + batchSize, tx)) + { + if (disabled) { + return; + } + var taskAwareAction = resolveTask(operationTask, tx); + if (taskAwareAction != null) { + actions.add(taskAwareAction); + } + } + tx.commit(); + } + }); + actions.forEach(operationsExecutor::startNew); + } catch (Exception e) { + LOG.error("Got exception while scheduling task", e); + metricsProvider.schedulerErrors(instanceId).inc(); + } + }, initialDelay.toMillis(), executionDelay.toMillis(), TimeUnit.MILLISECONDS); + } + + @Nullable + private OpTaskAwareAction resolveTask(OperationTask operationTask, TransactionHandle tx) 
throws SQLException { + var resolveResult = resolver.resolve(operationTask, tx); + if (resolveResult.status() != OperationTaskResolver.Status.SUCCESS) { + metricsProvider.schedulerResolveErrors(instanceId, resolveResult.status()); + } + switch (resolveResult.status()) { + case SUCCESS -> { + var updatedTask = setStatus(operationTask, OperationTask.Status.RUNNING, tx); + var action = resolveResult.action(); + assert action != null; + action.setTask(updatedTask); + return action; + } + case STALE -> { + LOG.warn("Marking task {} as STALE", operationTask.id(), resolveResult.exception()); + setStatus(operationTask, OperationTask.Status.STALE, tx); + } + case BAD_STATE -> { + LOG.error("Marking task {} as FAILED", operationTask.id(), resolveResult.exception()); + setStatus(operationTask, OperationTask.Status.FAILED, tx); + } + case UNKNOWN_TASK, RESOLVE_ERROR -> { + LOG.warn("Couldn't resolve task {}", operationTask.id(), resolveResult.exception()); + } + } + return null; + } + + private void restoreTasks() { + try { + var actionsToRun = new ArrayList(); + DbHelper.withRetries(LOG, () -> { + try (var tx = TransactionHandle.create(storage)) { + var tasks = opTaskDao.recaptureOldTasks(instanceId, leaseDuration, tx); + for (OperationTask operationTask : tasks) { + var taskAwareAction = resolveTask(operationTask, tx); + if (taskAwareAction != null) { + actionsToRun.add(taskAwareAction); + } + } + tx.commit(); + } + }); + actionsToRun.forEach(operationsExecutor::startNew); + } catch (Exception e) { + LOG.error("Got exception while restoring tasks", e); + metricsProvider.schedulerErrors(instanceId).inc(); + } + } + + private OperationTask setStatus(OperationTask operationTask, OperationTask.Status status, TransactionHandle tx) + throws SQLException + { + return opTaskDao.update(operationTask.id(), OperationTask.Update.builder() + .status(status) + .build(), tx); + } + + public void saveTask(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { + 
opTaskDao.insert(operationTask, tx); + } + + public ScheduledFuture startImmediately(OperationTask opTask) { + //schedule runnable to start task immediately. + //if task is already captured by main loop, this runnable will exit + return scheduler.schedule(() -> { + try { + var action = DbHelper.withRetries(LOG, () -> { + try (var tx = TransactionHandle.create(storage)) { + var lockedTask = opTaskDao.tryLockTask(opTask.id(), + opTask.entityId(), instanceId, leaseDuration, + tx); + if (lockedTask == null) { + return null; + } + var taskAwareAction = resolveTask(lockedTask, tx); + tx.commit(); + return taskAwareAction; + } + }); + if (action == null) { + return false; + } + operationsExecutor.startNew(action); + return true; + } catch (Exception e) { + LOG.error("Got exception while scheduling task", e); + metricsProvider.schedulerErrors(instanceId).inc(); + return false; + } + }, 0, TimeUnit.MILLISECONDS); + } + + public void shutdown() { + disabled = true; + scheduler.shutdown(); + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java new file mode 100644 index 0000000000..bc0208c555 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java @@ -0,0 +1,36 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; + +import java.sql.SQLException; + +public interface OperationTaskResolver { + Result resolve(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException; + + enum Status { + SUCCESS, + STALE, + BAD_STATE, + UNKNOWN_TASK, + RESOLVE_ERROR, + } + + record Result( + @Nullable OpTaskAwareAction action, + Status status, + @Nullable Exception exception + ) { + public static final Result STALE = new Result(null, Status.STALE, null); + public static final Result BAD_STATE = new Result(null, Status.BAD_STATE, null); 
+ public static final Result UNKNOWN_TASK = new Result(null, Status.UNKNOWN_TASK, null); + + public static Result success(OpTaskAwareAction action) { + return new Result(action, Status.SUCCESS, null); + } + + public static Result resolveError(Exception e) { + return new Result(null, Status.RESOLVE_ERROR, e); + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java new file mode 100644 index 0000000000..a2424047be --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java @@ -0,0 +1,19 @@ +package ai.lzy.longrunning.task; + +import jakarta.annotation.Nullable; + +import java.util.Map; + +public final class ResolverUtils { + private ResolverUtils() { + } + + @Nullable + public static String readString(Map meta, String key) { + var obj = meta.get(key); + if (obj instanceof String s) { + return s; + } + return null; + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java deleted file mode 100644 index e7800c6311..0000000000 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskAwareAction.java +++ /dev/null @@ -1,58 +0,0 @@ -package ai.lzy.longrunning.task; - -import ai.lzy.longrunning.OperationRunnerBase; -import ai.lzy.longrunning.OperationsExecutor; -import ai.lzy.longrunning.dao.OperationDao; -import ai.lzy.model.db.Storage; -import jakarta.annotation.Nullable; - -import java.time.Duration; -import java.util.Map; - -public abstract class TaskAwareAction extends OperationRunnerBase { - private final TaskQueue queue; - private final Duration leaseDuration; - private Task task; - - public TaskAwareAction(Task task, TaskQueue queue, Duration leaseDuration, String opId, String desc, - Storage storage, OperationDao operationsDao, OperationsExecutor executor) - { - super(opId, desc, storage, 
operationsDao, executor); - this.task = task; - this.queue = queue; - this.leaseDuration = leaseDuration; - } - - @Override - protected Map prepareLogContext() { - var ctx = super.prepareLogContext(); - ctx.put("task_id", String.valueOf(task.id())); - ctx.put("task_type", task.type()); - ctx.put("task_name", task.name()); - ctx.put("task_entity_id", task.entityId()); - return ctx; - } - - protected Task task() { - return task; - } - - @Override - protected void beforeStep() { - super.beforeStep(); - task = queue.updateLease(task.id(), leaseDuration); - } - - @Override - protected void notifyFinished(@Nullable Throwable t) { - super.notifyFinished(t); - - var builder = Task.Update.builder(); - if (t != null) { - builder.status(Task.Status.FAILED); - } else { - builder.status(Task.Status.FINISHED); - } - task = queue.update(task.id(), builder.build()); - } -} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java deleted file mode 100644 index f0f07eff7a..0000000000 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskExecutor.java +++ /dev/null @@ -1,63 +0,0 @@ -package ai.lzy.longrunning.task; - -import ai.lzy.longrunning.OperationsExecutor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import java.time.Duration; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; - -public class TaskExecutor { - - private static final Logger LOG = LogManager.getLogger(TaskExecutor.class); - - private final TaskQueue taskQueue; - private final OperationsExecutor operationsExecutor; - private final TaskResolver resolver; - private final ScheduledExecutorService scheduler; - private final Duration initialDelay; - private final Duration executionDelay; - - private volatile boolean started = false; - private volatile boolean disabled = false; - - 
/**
 * Source of the Prometheus gauges used by the operation task scheduler.
 * Every gauge is looked up by the id of the application instance.
 */
public interface TaskMetricsProvider {
    /** Gauge that {@code OperationTaskExecutor} increments when scheduling or restoring tasks throws. */
    Gauge schedulerErrors(String instanceId);

    /** Gauge that {@code OperationTaskExecutor} requests when a task resolves with a non-SUCCESS status. */
    Gauge schedulerResolveErrors(String instanceId, OperationTaskResolver.Status status);

    // NOTE(review): no caller is visible in this part of the patch; presumably the number of queued
    // tasks for the instance — confirm against the implementation.
    Gauge queueSize(String instanceId);
}
/**
 * Storage access for {@link OperationTask} rows.
 *
 * <p>Every method takes an optional {@link TransactionHandle}; {@code null} is accepted for {@code tx}.
 * The behaviour notes below describe what {@code OperationTaskDaoImpl} does.
 */
public interface OperationTaskDao {

    /** Returns the task with the given id, or {@code null} (presumably when no such row exists). */
    @Nullable
    OperationTask get(long id, @Nullable TransactionHandle tx) throws SQLException;

    /**
     * Applies {@code update} to the task with the given id.
     * The implementation returns {@code null} without touching storage when {@code update.isEmpty()}.
     */
    @Nullable
    OperationTask update(long id, OperationTask.Update update, @Nullable TransactionHandle tx) throws SQLException;

    /** Sets the lease of the task with the given id to the current database time plus {@code duration}. */
    @Nullable
    OperationTask updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException;

    /**
     * Inserts a new task and returns the stored row. The implementation's INSERT does not set the id and
     * uses the database clock for {@code created_at} / {@code updated_at}.
     */
    @Nullable
    OperationTask insert(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException;

    /** Deletes the task with the given id. */
    void delete(long id, @Nullable TransactionHandle tx) throws SQLException;

    /**
     * Locks a batch of pending tasks for {@code ownerId}: the selected tasks become {@code RUNNING}, get
     * {@code ownerId} as worker and a lease of {@code leaseDuration}. For every entity only its first
     * (lowest id) {@code PENDING}/{@code RUNNING} task is considered, and it is taken only if it is
     * {@code PENDING}. {@code batchSize} is bound as a query parameter that limits the selection.
     */
    List lockPendingBatch(String ownerId, Duration leaseDuration, int batchSize,
                                         @Nullable TransactionHandle tx) throws SQLException;

    /**
     * Renews the lease of every {@code RUNNING} task whose worker is {@code ownerId} and returns those tasks.
     */
    List recaptureOldTasks(String ownerId, Duration leaseDuration, @Nullable TransactionHandle tx)
        throws SQLException;

    /**
     * Tries to lock one specific task for {@code ownerId}. The lock succeeds only when the task with
     * {@code taskId} is the first (lowest id) {@code PENDING}/{@code RUNNING} task of {@code entityId}
     * and is still {@code PENDING}.
     *
     * @return the locked task, or {@code null} when it could not be locked
     */
    @Nullable
    OperationTask tryLockTask(Long taskId, String entityId, String ownerId, Duration leaseDuration,
                              @Nullable TransactionHandle tx) throws SQLException;
}
implements OperationTaskDao { private static final String FIELDS = "id, name, entity_id, type, status, created_at, updated_at, metadata," + " operation_id, worker_id, lease_till"; - public static final String SELECT_QUERY = "SELECT %s FROM task WHERE id = ?".formatted(FIELDS); + public static final String SELECT_QUERY = "SELECT %s FROM operation_task WHERE id = ?".formatted(FIELDS); public static final String INSERT_QUERY = """ - INSERT INTO task (name, entity_id, type, status, created_at, updated_at, metadata, operation_id) + INSERT INTO operation_task (name, entity_id, type, status, created_at, updated_at, metadata, operation_id) VALUES (?, ?, cast(? as task_type), cast(? as task_status), now(), now(), cast(? as jsonb), ?) RETURNING %s; """.formatted(FIELDS); //in first nested request we gather all tasks that either locked or free. //in second nested request we filter result of previous request to get only free tasks - // and select only specific amount. + //and select only specific amount. private static final String LOCK_PENDING_BATCH_QUERY = """ - UPDATE task + UPDATE operation_task SET status = 'RUNNING', worker_id = ?, updated_at = now(), lease_till = now() + cast(? as interval) WHERE id IN ( SELECT id - FROM task + FROM operation_task WHERE id IN ( SELECT DISTINCT ON (entity_id) id - FROM task + FROM operation_task WHERE status IN ('PENDING', 'RUNNING') ORDER BY entity_id, id ) AND status = 'PENDING' @@ -47,9 +47,37 @@ WHERE status IN ('PENDING', 'RUNNING') ) RETURNING %s; """.formatted(FIELDS); - public static final String DELETE_QUERY = "DELETE FROM task WHERE id = ?"; + + private static final String RECAPTURE_OLD_TASKS_QUERY = """ + UPDATE operation_task + SET updated_at = now(), lease_till = now() + cast(? as interval) + WHERE status = 'RUNNING' AND worker_id = ? + RETURNING %s; + """.formatted(FIELDS); + + //in first nested request we select all tasks by entity_id that are either locked or free and take the first one. 
+ //in second nested request we filter result of previous request to get only pending task with specific id. + //if we get the task then we lock it and return it. + private static final String TRY_LOCK_TASK_QUERY = """ + UPDATE operation_task + SET status = 'RUNNING', worker_id = ?, updated_at = now(), lease_till = now() + cast(? as interval) + WHERE id IN ( + SELECT id + FROM operation_task + WHERE id IN ( + SELECT id + FROM operation_task + WHERE status IN ('PENDING', 'RUNNING') AND entity_id = ? + ORDER BY id + LIMIT 1 + ) AND status = 'PENDING' AND id = ? + ) + RETURNING %s; + """.formatted(FIELDS); + + public static final String DELETE_QUERY = "DELETE FROM operation_task WHERE id = ?"; public static final String UPDATE_LEASE_QUERY = """ - UPDATE task + UPDATE operation_task SET lease_till = now() + cast(? as interval) WHERE id = ? RETURNING %s @@ -60,14 +88,14 @@ WHERE status IN ('PENDING', 'RUNNING') private final Storage storage; private final ObjectMapper objectMapper; - public TaskDaoImpl(Storage storage, ObjectMapper objectMapper) { + public OperationTaskDaoImpl(Storage storage, ObjectMapper objectMapper) { this.storage = storage; this.objectMapper = objectMapper; } @Nullable @Override - public Task get(long id, @Nullable TransactionHandle tx) throws SQLException { + public OperationTask get(long id, @Nullable TransactionHandle tx) throws SQLException { return DbOperation.execute(tx, storage, c -> { try (PreparedStatement ps = c.prepareStatement(SELECT_QUERY)) { ps.setLong(1, id); @@ -83,7 +111,9 @@ public Task get(long id, @Nullable TransactionHandle tx) throws SQLException { @Override @Nullable - public Task update(long id, Task.Update update, @Nullable TransactionHandle tx) throws SQLException { + public OperationTask update(long id, OperationTask.Update update, @Nullable TransactionHandle tx) + throws SQLException + { if (update.isEmpty()) { return null; } @@ -103,7 +133,7 @@ public Task update(long id, Task.Update update, @Nullable TransactionHandle 
tx) @Override @Nullable - public Task updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException { + public OperationTask updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException { return DbOperation.execute(tx, storage, c -> { try (PreparedStatement ps = c.prepareStatement(UPDATE_LEASE_QUERY)) { ps.setString(1, duration.toString()); @@ -119,16 +149,16 @@ public Task updateLease(long id, Duration duration, @Nullable TransactionHandle @Nullable @Override - public Task insert(Task task, @Nullable TransactionHandle tx) throws SQLException { + public OperationTask insert(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { return DbOperation.execute(tx, storage, c -> { try (PreparedStatement ps = c.prepareStatement(INSERT_QUERY)) { var i = 0; - ps.setString(++i, task.name()); - ps.setString(++i, task.entityId()); - ps.setString(++i, task.type()); - ps.setString(++i, task.status().name()); - ps.setString(++i, objectMapper.writeValueAsString(task.metadata())); - ps.setString(++i, task.operationId()); + ps.setString(++i, operationTask.name()); + ps.setString(++i, operationTask.entityId()); + ps.setString(++i, operationTask.type()); + ps.setString(++i, operationTask.status().name()); + ps.setString(++i, objectMapper.writeValueAsString(operationTask.metadata())); + ps.setString(++i, operationTask.operationId()); var rs = ps.executeQuery(); if (rs.next()) { return readTask(rs); @@ -151,17 +181,37 @@ public void delete(long id, @Nullable TransactionHandle tx) throws SQLException } @Override - public List lockPendingBatch(String ownerId, Duration leaseTime, int batchSize, - @Nullable TransactionHandle tx) throws SQLException + public List lockPendingBatch(String ownerId, Duration leaseDuration, int batchSize, + @Nullable TransactionHandle tx) throws SQLException { return DbOperation.execute(tx, storage, c -> { try (PreparedStatement ps = 
c.prepareStatement(LOCK_PENDING_BATCH_QUERY)) { var i = 0; ps.setString(++i, ownerId); - ps.setString(++i, leaseTime.toString()); + ps.setString(++i, leaseDuration.toString()); ps.setInt(++i, batchSize); try (var rs = ps.executeQuery()) { - var result = new ArrayList(rs.getFetchSize()); + var result = new ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + + @Override + public List recaptureOldTasks(String ownerId, Duration leaseTime, @Nullable TransactionHandle tx) + throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(RECAPTURE_OLD_TASKS_QUERY)) { + var i = 0; + ps.setString(++i, leaseTime.toString()); + ps.setString(++i, ownerId); + try (var rs = ps.executeQuery()) { + var result = new ArrayList(rs.getFetchSize()); while (rs.next()) { result.add(readTask(rs)); } @@ -171,8 +221,30 @@ public List lockPendingBatch(String ownerId, Duration leaseTime, int batch }); } - private static String prepareUpdateQuery(Task.Update update) { - var sb = new StringBuilder("UPDATE task SET "); + @Nullable + @Override + public OperationTask tryLockTask(Long taskId, String entityId, String ownerId, Duration leaseDuration, + @Nullable TransactionHandle tx) throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(TRY_LOCK_TASK_QUERY)) { + var i = 0; + ps.setString(++i, ownerId); + ps.setString(++i, leaseDuration.toString()); + ps.setString(++i, entityId); + ps.setLong(++i, taskId); + try (var rs = ps.executeQuery()) { + if (rs.next()) { + return readTask(rs); + } + return null; + } + } + }); + } + + private static String prepareUpdateQuery(OperationTask.Update update) { + var sb = new StringBuilder("UPDATE operation_task SET "); if (update.status() != null) { sb.append("status = cast(? 
as task_status), "); } @@ -187,7 +259,7 @@ private static String prepareUpdateQuery(Task.Update update) { return sb.toString(); } - private void prepareUpdateParameters(PreparedStatement ps, long id, Task.Update update) + private void prepareUpdateParameters(PreparedStatement ps, long id, OperationTask.Update update) throws JsonProcessingException, SQLException { var i = 0; @@ -203,24 +275,26 @@ private void prepareUpdateParameters(PreparedStatement ps, long id, Task.Update ps.setLong(++i, id); } - private Task readTask(ResultSet rs) throws SQLException { + private OperationTask readTask(ResultSet rs) throws SQLException { try { var id = rs.getLong("id"); var name = rs.getString("name"); var entityId = rs.getString("entity_id"); var type = rs.getString("type"); - var status = Task.Status.valueOf(rs.getString("status")); + var status = OperationTask.Status.valueOf(rs.getString("status")); var createdAt = rs.getTimestamp("created_at").toInstant(); var updatedAt = rs.getTimestamp("updated_at").toInstant(); - var metadata = objectMapper.readValue(rs.getString("metadata"), MAP_TYPE_REFERENCE); + var metadata = objectMapper.readValue(rs.getString("metadata"), + MAP_TYPE_REFERENCE); var operationId = rs.getString("operation_id"); var workerId = rs.getString("worker_id"); var leaseTillTs = rs.getTimestamp("lease_till"); var leaseTill = leaseTillTs != null ? 
leaseTillTs.toInstant() : null; - return new Task(id, name, entityId, type, status, createdAt, updatedAt, metadata, operationId, workerId, + return new OperationTask(id, name, entityId, type, status, createdAt, updatedAt, metadata, operationId, + workerId, leaseTill); } catch (JsonProcessingException e) { - throw new RuntimeException("Cannot read metadata for task object", e); + throw new RuntimeException("Cannot read metadata for operation task object", e); } } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java deleted file mode 100644 index 31a92ee5d6..0000000000 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/TaskDao.java +++ /dev/null @@ -1,29 +0,0 @@ -package ai.lzy.longrunning.task.dao; - -import ai.lzy.longrunning.task.Task; -import ai.lzy.model.db.TransactionHandle; -import jakarta.annotation.Nullable; - -import java.sql.SQLException; -import java.time.Duration; -import java.util.List; - -public interface TaskDao { - - @Nullable - Task get(long id, @Nullable TransactionHandle tx) throws SQLException; - - @Nullable - Task update(long id, Task.Update update, @Nullable TransactionHandle tx) throws SQLException; - - @Nullable - Task updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException; - - @Nullable - Task insert(Task task, @Nullable TransactionHandle tx) throws SQLException; - - void delete(long id, @Nullable TransactionHandle tx) throws SQLException; - - List lockPendingBatch(String ownerId, Duration leaseTime, int batchSize, @Nullable TransactionHandle tx) - throws SQLException; -} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java deleted file mode 100644 index b1d94af886..0000000000 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/DaoTaskQueueTest.java +++ /dev/null @@ -1,129 +0,0 
@@ -package ai.lzy.longrunning; - -import ai.lzy.longrunning.task.DaoTaskQueue; -import ai.lzy.longrunning.task.Task; -import ai.lzy.longrunning.task.dao.TaskDao; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mockito; - -import java.time.Duration; -import java.time.temporal.ChronoUnit; -import java.util.List; -import java.util.Map; - -public class DaoTaskQueueTest { - - private static final int MAX_QUEUE_SIZE = 10; - private static final Duration INITIAL_LEASE_TIME = Duration.of(5, ChronoUnit.MINUTES); - private static final String INSTANCE_ID = "worker"; - - private DaoTaskQueue taskQueue; - private TaskDao taskDaoMock; - - @Before - public void setup() { - taskDaoMock = Mockito.mock(TaskDao.class); - taskQueue = new DaoTaskQueue(taskDaoMock, MAX_QUEUE_SIZE, INITIAL_LEASE_TIME, - INSTANCE_ID); - } - - @Test - public void add() throws Exception { - Mockito.when(taskDaoMock.insert(Mockito.any(), Mockito.any())).thenReturn(null); - var task = Task.createPending("task", "1", "MOUNT", Map.of()); - taskQueue.add(task); - Mockito.verify(taskDaoMock, Mockito.only()).insert(task, null); - } - - @Test - public void update() throws Exception { - Mockito.when(taskDaoMock.update(Mockito.anyLong(), Mockito.any(), Mockito.any())).thenReturn(null); - var task = Task.createPending("task", "1", "MOUNT", Map.of()); - var update = Task.Update.builder().status(Task.Status.RUNNING).build(); - taskQueue.update(task.id(), update); - Mockito.verify(taskDaoMock, Mockito.only()).update(task.id(), update, null); - } - - @Test - public void updateLease() throws Exception { - Mockito.when(taskDaoMock.update(Mockito.anyLong(), Mockito.any(), Mockito.any())).thenReturn(null); - var task = Task.createPending("task", "1", "MOUNT", Map.of()); - var duration = Duration.ofSeconds(42); - taskQueue.updateLease(task.id(), duration); - Mockito.verify(taskDaoMock, Mockito.only()).updateLease(task.id(), duration, null); - } - - @Test - public void 
pollNext() throws Exception { - var tasks = List.of( - Task.createPending("task1", "1", "MOUNT", Map.of()), - Task.createPending("task2", "2", "MOUNT", Map.of()), - Task.createPending("task3", "3", "MOUNT", Map.of()) - ); - Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) - .thenReturn(tasks); - var polledTask1 = taskQueue.pollNext(); - Mockito.verify(taskDaoMock, Mockito.only()) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - var polledTask2 = taskQueue.pollNext(); - var polledTask3 = taskQueue.pollNext(); - Mockito.verify(taskDaoMock, Mockito.only()) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - Assert.assertEquals(tasks.get(0), polledTask1); - Assert.assertEquals(tasks.get(1), polledTask2); - Assert.assertEquals(tasks.get(2), polledTask3); - - var anotherBatch = List.of(Task.createPending("task4", "4", "MOUNT", Map.of())); - Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) - .thenReturn(anotherBatch); - var polledTask4 = taskQueue.pollNext(); - Mockito.verify(taskDaoMock, Mockito.times(2)) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - Assert.assertEquals(anotherBatch.get(0), polledTask4); - - Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) - .thenReturn(List.of()); - var polledTask5 = taskQueue.pollNext(); - Mockito.verify(taskDaoMock, Mockito.times(3)) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - Assert.assertNull(polledTask5); - } - - @Test - public void delete() throws Exception { - Mockito.doNothing().when(taskDaoMock).delete(Mockito.anyLong(), Mockito.any()); - var task = Task.createPending("task", "1", "MOUNT", Map.of()); - taskQueue.delete(task.id()); - Mockito.verify(taskDaoMock, Mockito.only()).delete(task.id(), null); - } - - @Test - public 
void pollRemaining() throws Exception { - var tasks = List.of( - Task.createPending("task1", "1", "MOUNT", Map.of()), - Task.createPending("task2", "2", "MOUNT", Map.of()), - Task.createPending("task3", "3", "MOUNT", Map.of()) - ); - Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) - .thenReturn(tasks); - var task = taskQueue.pollNext(); - Assert.assertEquals(tasks.get(0), task); - Mockito.verify(taskDaoMock, Mockito.only()) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - - var remainingTasks = taskQueue.pollRemaining(); - Mockito.verify(taskDaoMock, Mockito.only()) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - Assert.assertEquals(tasks.subList(1, 3), remainingTasks); - - Mockito.when(taskDaoMock.lockPendingBatch(Mockito.anyString(), Mockito.any(), Mockito.anyInt(), Mockito.any())) - .thenReturn(List.of()); - var newTasks = taskQueue.pollRemaining(); - Mockito.verify(taskDaoMock, Mockito.times(2)) - .lockPendingBatch(INSTANCE_ID, INITIAL_LEASE_TIME, MAX_QUEUE_SIZE, null); - Assert.assertEquals(List.of(), newTasks); - } - -} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java new file mode 100644 index 0000000000..1af8c09ac3 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java @@ -0,0 +1,159 @@ +package ai.lzy.longrunning; + +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.longrunning.task.dao.OperationTaskDaoImpl; +import ai.lzy.model.db.StorageImpl; +import ai.lzy.model.db.test.DatabaseTestUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; +import io.zonky.test.db.postgres.junit.PreparedDbRule; +import org.junit.*; + +import java.time.Duration; +import java.util.Map; +import java.util.Set; 
+import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class OperationTaskDaoImplTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); + + private StorageImpl storage; + private OperationTaskDaoImpl taskDao; + + + @Before + public void setup() { + storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), + "classpath:db/migrations") {}; + taskDao = new OperationTaskDaoImpl(storage, new ObjectMapper()); + } + + @After + public void teardown() { + DatabaseTestUtils.cleanup(storage); + storage.close(); + } + + @Test + public void create() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var fetched = taskDao.get(task.id(), null); + assertEquals(task, fetched); + } + + @Test + public void multiCreate() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task2 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task3 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + Assert.assertTrue(task1.id() < task2.id()); + Assert.assertTrue(task2.id() < task3.id()); + } + + @Test + public void update() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var updated = taskDao.update(task.id(), OperationTask.Update.builder().status(OperationTask.Status.RUNNING).build(), null); + assertEquals(OperationTask.Status.RUNNING, updated.status()); + updated = taskDao.update(task.id(), OperationTask.Update.builder().operationId("42").build(), null); + assertEquals("42", updated.operationId()); + updated = taskDao.update(task.id(), OperationTask.Update.builder().metadata(Map.of("qux", 
"quux")).build(), null); + assertEquals(Map.of("qux", "quux"), updated.metadata()); + updated = taskDao.update(task.id(), OperationTask.Update.builder() + .status(OperationTask.Status.FINISHED) + .operationId("0") + .metadata(Map.of()) + .build(), null); + assertEquals(OperationTask.Status.FINISHED, updated.status()); + assertEquals("0", updated.operationId()); + assertEquals(Map.of(), updated.metadata()); + } + + @Test + public void delete() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + taskDao.delete(task.id(), null); + } + + @Test + public void updateLease() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var updateLease = taskDao.updateLease(task.id(), Duration.ofMinutes(5), null); + var between = Duration.between(task.createdAt(), updateLease.leaseTill()); + //leaseTill should be around 5 minutes from now + assertTrue(between.compareTo(Duration.ofMinutes(5)) >= 0); + } + + @Test + public void getUnknown() throws Exception { + var task = taskDao.get(42, null); + Assert.assertNull(task); + } + + @Test + public void updateUnknown() throws Exception { + var updated = taskDao.update(42, OperationTask.Update.builder().status(OperationTask.Status.RUNNING).build(), null); + Assert.assertNull(updated); + } + + @Test + public void deleteUnknown() throws Exception { + taskDao.delete(42, null); + } + + @Test + public void updateLeaseUnknown() throws Exception { + var updated = taskDao.updateLease(42, Duration.ofMinutes(5), null); + Assert.assertNull(updated); + } + + @Test + public void lockPendingBatch() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of()), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of()), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of()), 
null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of()), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of()), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of()), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of()), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of()), null); + + task2 = taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(OperationTask.Status.FINISHED), null); + task4 = taskDao.update(task4.id(), statusUpdate(OperationTask.Status.FAILED), null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + var lockedTaskIds = lockedTasks.stream().map(OperationTask::id).collect(Collectors.toSet()); + assertEquals(Set.of(task1.id(), task7.id(), task8.id()), lockedTaskIds); + } + + @Test + public void lockPendingBatchWithAllRunning() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of()), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of()), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of()), null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of()), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of()), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of()), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of()), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of()), null); + + task1 = taskDao.update(task1.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task2 = 
taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task4 = taskDao.update(task4.id(), statusUpdate(OperationTask.Status.RUNNING), null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + assertTrue(lockedTasks.isEmpty()); + } + + private OperationTask.Update statusUpdate(OperationTask.Status status) { + return OperationTask.Update.builder().status(status).build(); + } + +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java deleted file mode 100644 index 6d8f9d09ea..0000000000 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/TaskDaoImplTest.java +++ /dev/null @@ -1,159 +0,0 @@ -package ai.lzy.longrunning; - -import ai.lzy.longrunning.task.Task; -import ai.lzy.longrunning.task.dao.TaskDaoImpl; -import ai.lzy.model.db.StorageImpl; -import ai.lzy.model.db.test.DatabaseTestUtils; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; -import io.zonky.test.db.postgres.junit.PreparedDbRule; -import org.junit.*; - -import java.time.Duration; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class TaskDaoImplTest { - - @Rule - public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); - - private StorageImpl storage; - private TaskDaoImpl taskDao; - - - @Before - public void setup() { - storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), - "classpath:db/migrations") {}; - taskDao = new TaskDaoImpl(storage, new ObjectMapper()); - } - - @After - public void teardown() { - DatabaseTestUtils.cleanup(storage); - storage.close(); - } - - @Test - public void 
create() throws Exception { - var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var fetched = taskDao.get(task.id(), null); - assertEquals(task, fetched); - } - - @Test - public void multiCreate() throws Exception { - var task1 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var task2 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var task3 = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - Assert.assertTrue(task1.id() < task2.id()); - Assert.assertTrue(task2.id() < task3.id()); - } - - @Test - public void update() throws Exception { - var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var updated = taskDao.update(task.id(), Task.Update.builder().status(Task.Status.RUNNING).build(), null); - assertEquals(Task.Status.RUNNING, updated.status()); - updated = taskDao.update(task.id(), Task.Update.builder().operationId("42").build(), null); - assertEquals("42", updated.operationId()); - updated = taskDao.update(task.id(), Task.Update.builder().metadata(Map.of("qux", "quux")).build(), null); - assertEquals(Map.of("qux", "quux"), updated.metadata()); - updated = taskDao.update(task.id(), Task.Update.builder() - .status(Task.Status.FINISHED) - .operationId("0") - .metadata(Map.of()) - .build(), null); - assertEquals(Task.Status.FINISHED, updated.status()); - assertEquals("0", updated.operationId()); - assertEquals(Map.of(), updated.metadata()); - } - - @Test - public void delete() throws Exception { - var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - taskDao.delete(task.id(), null); - } - - @Test - public void updateLease() throws Exception { - var task = taskDao.insert(Task.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var updateLease = taskDao.updateLease(task.id(), 
Duration.ofMinutes(5), null); - var between = Duration.between(task.createdAt(), updateLease.leaseTill()); - //leaseTill should be around 5 minutes from now - assertTrue(between.compareTo(Duration.ofMinutes(5)) >= 0); - } - - @Test - public void getUnknown() throws Exception { - var task = taskDao.get(42, null); - Assert.assertNull(task); - } - - @Test - public void updateUnknown() throws Exception { - var updated = taskDao.update(42, Task.Update.builder().status(Task.Status.RUNNING).build(), null); - Assert.assertNull(updated); - } - - @Test - public void deleteUnknown() throws Exception { - taskDao.delete(42, null); - } - - @Test - public void updateLeaseUnknown() throws Exception { - var updated = taskDao.updateLease(42, Duration.ofMinutes(5), null); - Assert.assertNull(updated); - } - - @Test - public void lockPendingBatch() throws Exception { - var task1 = taskDao.insert(Task.createPending("task1", "1", "MOUNT", Map.of()), null); - var task2 = taskDao.insert(Task.createPending("task2", "2", "MOUNT", Map.of()), null); - var task3 = taskDao.insert(Task.createPending("task3", "3", "MOUNT", Map.of()), null); - var task4 = taskDao.insert(Task.createPending("task4", "4", "MOUNT", Map.of()), null); - var task5 = taskDao.insert(Task.createPending("task5", "1", "MOUNT", Map.of()), null); - var task6 = taskDao.insert(Task.createPending("task6", "2", "MOUNT", Map.of()), null); - var task7 = taskDao.insert(Task.createPending("task7", "3", "MOUNT", Map.of()), null); - var task8 = taskDao.insert(Task.createPending("task8", "4", "MOUNT", Map.of()), null); - - task2 = taskDao.update(task2.id(), statusUpdate(Task.Status.RUNNING), null); - task3 = taskDao.update(task3.id(), statusUpdate(Task.Status.FINISHED), null); - task4 = taskDao.update(task4.id(), statusUpdate(Task.Status.FAILED), null); - - var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); - var lockedTaskIds = lockedTasks.stream().map(Task::id).collect(Collectors.toSet()); - 
assertEquals(Set.of(task1.id(), task7.id(), task8.id()), lockedTaskIds); - } - - @Test - public void lockPendingBatchWithAllRunning() throws Exception { - var task1 = taskDao.insert(Task.createPending("task1", "1", "MOUNT", Map.of()), null); - var task2 = taskDao.insert(Task.createPending("task2", "2", "MOUNT", Map.of()), null); - var task3 = taskDao.insert(Task.createPending("task3", "3", "MOUNT", Map.of()), null); - var task4 = taskDao.insert(Task.createPending("task4", "4", "MOUNT", Map.of()), null); - var task5 = taskDao.insert(Task.createPending("task5", "1", "MOUNT", Map.of()), null); - var task6 = taskDao.insert(Task.createPending("task6", "2", "MOUNT", Map.of()), null); - var task7 = taskDao.insert(Task.createPending("task7", "3", "MOUNT", Map.of()), null); - var task8 = taskDao.insert(Task.createPending("task8", "4", "MOUNT", Map.of()), null); - - task1 = taskDao.update(task1.id(), statusUpdate(Task.Status.RUNNING), null); - task2 = taskDao.update(task2.id(), statusUpdate(Task.Status.RUNNING), null); - task3 = taskDao.update(task3.id(), statusUpdate(Task.Status.RUNNING), null); - task4 = taskDao.update(task4.id(), statusUpdate(Task.Status.RUNNING), null); - - var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); - assertTrue(lockedTasks.isEmpty()); - } - - private Task.Update statusUpdate(Task.Status status) { - return Task.Update.builder().status(status).build(); - } - -} diff --git a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql index 7c14e8c060..fe851476b4 100644 --- a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql +++ b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql @@ -2,7 +2,7 @@ CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); -CREATE TABLE IF NOT EXISTS task( +CREATE TABLE IF NOT EXISTS operation_task( id BIGSERIAL NOT NULL, 
name TEXT NOT NULL, entity_id TEXT NOT NULL, From 0d7e4e5ebd41cd6e8bbe44d1ebde8e6c0e55c806 Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 00:19:49 +0300 Subject: [PATCH 4/9] inc migration --- .../{V9__operation_task.sql => V10__operation_task.sql} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lzy/allocator/src/main/resources/db/allocator/migrations/{V9__operation_task.sql => V10__operation_task.sql} (100%) diff --git a/lzy/allocator/src/main/resources/db/allocator/migrations/V9__operation_task.sql b/lzy/allocator/src/main/resources/db/allocator/migrations/V10__operation_task.sql similarity index 100% rename from lzy/allocator/src/main/resources/db/allocator/migrations/V9__operation_task.sql rename to lzy/allocator/src/main/resources/db/allocator/migrations/V10__operation_task.sql From 2c601d88022d2fa980dc9c0b7dfe569e098790bc Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 00:56:23 +0300 Subject: [PATCH 5/9] and task quota per instance --- .../longrunning/task/OpTaskAwareAction.java | 6 ++- .../task/OperationTaskExecutor.java | 50 ++++++++++++++++--- .../longrunning/task/TaskMetricsProvider.java | 7 +-- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java index 6e262e6f1d..7e3790b916 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java @@ -13,18 +13,20 @@ import static ai.lzy.model.db.DbHelper.withRetries; public abstract class OpTaskAwareAction extends OperationRunnerBase { + private final OperationTaskExecutor operationTaskExecutor; private final OperationTaskDao operationTaskDao; private final Duration leaseDuration; private OperationTask operationTask; public 
OpTaskAwareAction(OperationTask operationTask, OperationTaskDao operationTaskDao, Duration leaseDuration, String opId, String desc, Storage storage, OperationDao operationsDao, - OperationsExecutor executor) + OperationsExecutor executor, OperationTaskExecutor operationTaskExecutor) { super(opId, desc, storage, operationsDao, executor); this.operationTask = operationTask; this.operationTaskDao = operationTaskDao; this.leaseDuration = leaseDuration; + this.operationTaskExecutor = operationTaskExecutor; } @Override @@ -52,6 +54,7 @@ protected void beforeStep() { operationTask = withRetries(log(), () -> operationTaskDao.updateLease(operationTask.id(), leaseDuration, null)); } catch (Exception e) { + log().error("{} Couldn't update lease on task {}", logPrefix(), task().id()); throw new RuntimeException(e); } } @@ -72,5 +75,6 @@ protected void notifyFinished(@Nullable Throwable t) { } catch (Exception e) { log().error("{} Couldn't finish operation task {}", logPrefix(), task().id()); } + operationTaskExecutor.releaseTask(task()); } } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java index 4ab0a53943..ab6df96c0b 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java @@ -16,6 +16,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; public class OperationTaskExecutor { @@ -32,6 +33,7 @@ public class OperationTaskExecutor { private final String instanceId; private final Duration leaseDuration; private final int batchSize; + private final AtomicInteger runningTaskQuota = new AtomicInteger(); private volatile boolean started = false; private volatile boolean disabled = false; @@ -39,7 
+41,7 @@ public class OperationTaskExecutor { public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor, OperationTaskResolver resolver, Duration initialDelay, Duration executionDelay, Storage storage, TaskMetricsProvider metricsProvider, String instanceId, - Duration leaseDuration, int batchSize) + Duration leaseDuration, int batchSize, int maxRunningTasks) { this.opTaskDao = opTaskDao; this.operationsExecutor = operationsExecutor; @@ -52,6 +54,7 @@ public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor oper this.scheduler = Executors.newSingleThreadScheduledExecutor(); //it's important to have only one thread this.metricsProvider = metricsProvider; this.instanceId = instanceId; + this.runningTaskQuota.set(maxRunningTasks); } public void start() { @@ -63,15 +66,36 @@ public void start() { startMailLoop(); } - //todo backpressure - to not start new tasks if there are too many of them + private void acquireTask() { + acquireTasks(1); + } + + private void acquireTasks(int count) { + var newQuota = runningTaskQuota.addAndGet(-count); + LOG.debug("Acquired {} tasks to run. 
Current quota: {}", count, newQuota); + metricsProvider.runningTasks().inc(count); + } + + public void releaseTask(OperationTask task) { + LOG.debug("Finishing task {}", task.id()); + runningTaskQuota.incrementAndGet(); + metricsProvider.runningTasks().dec(); + } + private ScheduledFuture startMailLoop() { return scheduler.scheduleWithFixedDelay(() -> { try { var actions = new ArrayList(); + if (!hasQuota()) { + LOG.info("Not enough quota to start new operation tasks"); + return; + } + var quota = runningTaskQuota.get(); + var toLoad = Math.min(batchSize, quota); DbHelper.withRetries(LOG, () -> { try (var tx = TransactionHandle.create(storage)) { for (OperationTask operationTask : opTaskDao.lockPendingBatch(instanceId, leaseDuration, - batchSize, tx)) + toLoad, tx)) { if (disabled) { return; @@ -84,10 +108,11 @@ private ScheduledFuture startMailLoop() { tx.commit(); } }); + acquireTasks(actions.size()); actions.forEach(operationsExecutor::startNew); } catch (Exception e) { LOG.error("Got exception while scheduling task", e); - metricsProvider.schedulerErrors(instanceId).inc(); + metricsProvider.schedulerErrors().inc(); } }, initialDelay.toMillis(), executionDelay.toMillis(), TimeUnit.MILLISECONDS); } @@ -96,7 +121,7 @@ private ScheduledFuture startMailLoop() { private OpTaskAwareAction resolveTask(OperationTask operationTask, TransactionHandle tx) throws SQLException { var resolveResult = resolver.resolve(operationTask, tx); if (resolveResult.status() != OperationTaskResolver.Status.SUCCESS) { - metricsProvider.schedulerResolveErrors(instanceId, resolveResult.status()); + metricsProvider.schedulerResolveErrors(resolveResult.status()).inc(); } switch (resolveResult.status()) { case SUCCESS -> { @@ -124,6 +149,7 @@ private OpTaskAwareAction resolveTask(OperationTask operationTask, TransactionHa private void restoreTasks() { try { var actionsToRun = new ArrayList(); + //allow over-quoting to complete unfinished tasks first DbHelper.withRetries(LOG, () -> { try (var tx 
= TransactionHandle.create(storage)) { var tasks = opTaskDao.recaptureOldTasks(instanceId, leaseDuration, tx); @@ -136,10 +162,11 @@ private void restoreTasks() { tx.commit(); } }); + acquireTasks(actionsToRun.size()); actionsToRun.forEach(operationsExecutor::startNew); } catch (Exception e) { LOG.error("Got exception while restoring tasks", e); - metricsProvider.schedulerErrors(instanceId).inc(); + metricsProvider.schedulerErrors().inc(); } } @@ -160,6 +187,10 @@ public ScheduledFuture startImmediately(OperationTask opTask) { //if task is already captured by main loop, this runnable will exit return scheduler.schedule(() -> { try { + if (!hasQuota()) { + LOG.debug("Not enough quota to start immediate operation task {}", opTask.id()); + return false; + } var action = DbHelper.withRetries(LOG, () -> { try (var tx = TransactionHandle.create(storage)) { var lockedTask = opTaskDao.tryLockTask(opTask.id(), @@ -176,16 +207,21 @@ public ScheduledFuture startImmediately(OperationTask opTask) { if (action == null) { return false; } + acquireTask(); operationsExecutor.startNew(action); return true; } catch (Exception e) { LOG.error("Got exception while scheduling task", e); - metricsProvider.schedulerErrors(instanceId).inc(); + metricsProvider.schedulerErrors().inc(); return false; } }, 0, TimeUnit.MILLISECONDS); } + private boolean hasQuota() { + return runningTaskQuota.get() > 0; + } + public void shutdown() { disabled = true; scheduler.shutdown(); diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java index 0efef3145f..448ebb6b89 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java @@ -3,7 +3,8 @@ import io.prometheus.client.Gauge; public interface TaskMetricsProvider { - Gauge schedulerErrors(String instanceId); - Gauge 
schedulerResolveErrors(String instanceId, OperationTaskResolver.Status status); - Gauge queueSize(String instanceId); + Gauge schedulerErrors(); + Gauge schedulerResolveErrors(OperationTaskResolver.Status status); + Gauge queueSize(); + Gauge runningTasks(); } From 486360bd412f9196148a7ca57e588d89baa085fb Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 13:42:43 +0300 Subject: [PATCH 6/9] add tests --- .../lzy/allocator/alloc/AllocateVmAction.java | 2 +- .../alloc/MountDynamicDiskAction.java | 2 +- .../alloc/UnmountDynamicDiskAction.java | 2 +- .../lzy/longrunning/OperationRunnerBase.java | 19 ++- .../DispatchingOperationTaskResolver.java | 6 + .../longrunning/task/OpTaskAwareAction.java | 7 +- .../lzy/longrunning/task/OperationTask.java | 6 +- .../task/OperationTaskExecutor.java | 10 +- .../longrunning/OperationTaskDaoImplTest.java | 46 ++--- .../task/OperationTaskExecutorTest.java | 157 ++++++++++++++++++ .../longrunning/task/StubMetricsProvider.java | 38 +++++ .../ai/lzy/longrunning/task/TestAction.java | 51 ++++++ .../test/resources/db/migrations/V1__task.sql | 28 +++- 13 files changed, 329 insertions(+), 45 deletions(-) create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java create mode 100644 lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java index bb6625d4e8..0f1f0befc2 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/AllocateVmAction.java @@ -67,7 +67,7 @@ protected void notifyExpired() { } @Override - protected void notifyFinished(@Nullable Throwable t) { + protected void notifyFinished() { 
allocationContext.metrics().runningAllocations.labels(vm.poolLabel()).dec(); if (deleteVmAction != null) { diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java index f097e53c79..97378fd0bd 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java @@ -68,7 +68,7 @@ protected boolean isInjectedError(Error e) { } @Override - protected void notifyFinished(@Nullable Throwable t) { + protected void notifyFinished() { if (unmountAction != null) { log().error("{} Failed to mount dynamic disk", logPrefix()); try { diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java index 84a8d5c064..8294a2aec9 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/UnmountDynamicDiskAction.java @@ -59,7 +59,7 @@ public static String description(@Nullable Vm vm, DynamicMount mount) { } @Override - protected void notifyFinished(@Nullable Throwable t) { + protected void notifyFinished() { log().info("{} Finished unmounting volume {}", logPrefix(), dynamicMount.id()); } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java index 15d4e70543..61ef5bcda7 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java @@ -32,6 +32,7 @@ public abstract class OperationRunnerBase extends ContextAwareTask { private final OperationDao operationsDao; private final OperationsExecutor executor; private Operation op; + private volatile boolean 
failed = false; protected OperationRunnerBase(String id, String descr, Storage storage, OperationDao operationsDao, OperationsExecutor executor) @@ -84,7 +85,8 @@ protected final void execute() { } } } catch (Throwable e) { - notifyFinished(e); + setFailed(); + notifyFinished(); if (e instanceof Error err && isInjectedError(err)) { log.error("{} Terminated by InjectedFailure exception: {}", logPrefix, e.getMessage()); } else { @@ -99,6 +101,14 @@ protected final void execute() { } } + protected void setFailed() { + failed = true; + } + + protected boolean isFailed() { + return failed; + } + protected void beforeStep() { } @@ -277,14 +287,11 @@ protected void onNotFound(@Nullable TransactionHandle tx) throws SQLException { protected void onCompletedOutside(Operation op, @Nullable TransactionHandle tx) throws SQLException { } - private void notifyFinished() { - notifyFinished(null); - } - - protected void notifyFinished(@Nullable Throwable t) { + protected void notifyFinished() { } protected final void failOperation(Status status, @Nullable TransactionHandle tx) throws SQLException { + setFailed(); operationsDao.fail(id, toProto(status), tx); } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java index 83c1d83621..8656b72de2 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java @@ -1,6 +1,7 @@ package ai.lzy.longrunning.task; import ai.lzy.model.db.TransactionHandle; +import com.google.common.annotations.VisibleForTesting; import jakarta.annotation.Nullable; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -33,6 +34,11 @@ private static Map generateResolversMap( .collect(Collectors.toMap(TypedOperationTaskResolver::type, r -> r)); } + 
@VisibleForTesting + void addResolver(TypedOperationTaskResolver resolver) { + resolvers.put(resolver.type(), resolver); + } + @Override public Result resolve(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { var resolver = resolvers.get(operationTask.type()); diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java index 7e3790b916..fc9a30bbea 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java @@ -5,7 +5,6 @@ import ai.lzy.longrunning.dao.OperationDao; import ai.lzy.longrunning.task.dao.OperationTaskDao; import ai.lzy.model.db.Storage; -import jakarta.annotation.Nullable; import java.time.Duration; import java.util.Map; @@ -60,11 +59,9 @@ protected void beforeStep() { } @Override - protected void notifyFinished(@Nullable Throwable t) { - super.notifyFinished(t); - + protected void notifyFinished() { var builder = OperationTask.Update.builder(); - if (t != null) { + if (isFailed()) { builder.status(OperationTask.Status.FAILED); } else { builder.status(OperationTask.Status.FINISHED); diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java index 8fc130a2d3..8e4ef3dbaf 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java @@ -21,9 +21,11 @@ public record OperationTask( @Nullable Instant leaseTill ) { - public static OperationTask createPending(String name, String entityId, String type, Map metadata) { + public static OperationTask createPending(String name, String entityId, String type, Map metadata, + String operationId) + { return new OperationTask(-1, name, entityId, type, 
Status.PENDING, Instant.now(), Instant.now(), - metadata, null, null, null); + metadata, operationId, null, null); } public enum Status { diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java index ab6df96c0b..4e2744c2d1 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java @@ -16,6 +16,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; public class OperationTaskExecutor { @@ -35,7 +36,7 @@ public class OperationTaskExecutor { private final int batchSize; private final AtomicInteger runningTaskQuota = new AtomicInteger(); - private volatile boolean started = false; + private final AtomicBoolean started = new AtomicBoolean(false); private volatile boolean disabled = false; public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor, @@ -58,10 +59,9 @@ public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor oper } public void start() { - if (started) { + if (!started.compareAndSet(false, true)) { throw new IllegalStateException("Task executor has already started!"); } - started = true; restoreTasks(); startMailLoop(); } @@ -178,8 +178,8 @@ private OperationTask setStatus(OperationTask operationTask, OperationTask.Statu .build(), tx); } - public void saveTask(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { - opTaskDao.insert(operationTask, tx); + public OperationTask saveTask(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { + return opTaskDao.insert(operationTask, tx); } public ScheduledFuture 
startImmediately(OperationTask opTask) { diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java index 1af8c09ac3..c0a7da1a1d 100644 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java @@ -41,23 +41,23 @@ public void teardown() { @Test public void create() throws Exception { - var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); var fetched = taskDao.get(task.id(), null); assertEquals(task, fetched); } @Test public void multiCreate() throws Exception { - var task1 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var task2 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); - var task3 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task1 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var task2 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var task3 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); Assert.assertTrue(task1.id() < task2.id()); Assert.assertTrue(task2.id() < task3.id()); } @Test public void update() throws Exception { - var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); var updated = taskDao.update(task.id(), 
OperationTask.Update.builder().status(OperationTask.Status.RUNNING).build(), null); assertEquals(OperationTask.Status.RUNNING, updated.status()); updated = taskDao.update(task.id(), OperationTask.Update.builder().operationId("42").build(), null); @@ -76,13 +76,13 @@ public void update() throws Exception { @Test public void delete() throws Exception { - var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); taskDao.delete(task.id(), null); } @Test public void updateLease() throws Exception { - var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar")), null); + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); var updateLease = taskDao.updateLease(task.id(), Duration.ofMinutes(5), null); var between = Duration.between(task.createdAt(), updateLease.leaseTill()); //leaseTill should be around 5 minutes from now @@ -114,14 +114,14 @@ public void updateLeaseUnknown() throws Exception { @Test public void lockPendingBatch() throws Exception { - var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of()), null); - var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of()), null); - var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of()), null); - var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of()), null); - var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of()), null); - var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of()), null); - var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of()), null); - var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of()), 
null); + var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of(), null), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of(), null), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of(), null), null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of(), null), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of(), null), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of(), null), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of(), null), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of(), null), null); task2 = taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); task3 = taskDao.update(task3.id(), statusUpdate(OperationTask.Status.FINISHED), null); @@ -134,14 +134,14 @@ public void lockPendingBatch() throws Exception { @Test public void lockPendingBatchWithAllRunning() throws Exception { - var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of()), null); - var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of()), null); - var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of()), null); - var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of()), null); - var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of()), null); - var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of()), null); - var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of()), null); - var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of()), null); + var task1 = 
taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of(), null), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of(), null), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of(), null), null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of(), null), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of(), null), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of(), null), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of(), null), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of(), null), null); task1 = taskDao.update(task1.id(), statusUpdate(OperationTask.Status.RUNNING), null); task2 = taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java new file mode 100644 index 0000000000..14bee02fbe --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java @@ -0,0 +1,157 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.Operation; +import ai.lzy.longrunning.OperationRunnerBase; +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDaoImpl; +import ai.lzy.longrunning.task.dao.OperationTaskDaoImpl; +import ai.lzy.model.db.StorageImpl; +import ai.lzy.model.db.TransactionHandle; +import ai.lzy.model.db.test.DatabaseTestUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; +import io.zonky.test.db.postgres.junit.PreparedDbRule; +import jakarta.annotation.Nullable; +import org.junit.After; +import 
org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.LockSupport; +import java.util.function.Supplier; + +import static org.junit.Assert.*; + +public class OperationTaskExecutorTest { + + public static final String MOUNT_TASK_TYPE = "MOUNT"; + public static final Duration SCHEDULER_DELAY = Duration.ofMinutes(5); + public static final Duration LEASE_DURATION = Duration.ofMinutes(5); + public static final int BATCH_SIZE = 7; + public static final int MAX_RUNNING_TASKS = 10; + public static final String WORKER_ID = "worker-42"; + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); + + @Rule + public Timeout timeout = Timeout.seconds(30); + + private StorageImpl storage; + private OperationTaskDaoImpl taskDao; + private OperationTaskExecutor taskExecutor; + private OperationDaoImpl operationDao; + private OperationsExecutor operationsExecutor; + private DispatchingOperationTaskResolver taskResolver; + + + @Before + public void setup() { + storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), + "classpath:db/migrations") {}; + taskDao = new OperationTaskDaoImpl(storage, new ObjectMapper()); + operationDao = new OperationDaoImpl(storage); + operationsExecutor = new OperationsExecutor(5, 10, () -> {}, e -> false); + taskResolver = new DispatchingOperationTaskResolver(List.of()); + taskExecutor = new OperationTaskExecutor(taskDao, operationsExecutor, taskResolver, Duration.ZERO, + SCHEDULER_DELAY, storage, new StubMetricsProvider(), WORKER_ID , LEASE_DURATION, + BATCH_SIZE, MAX_RUNNING_TASKS); + + } + + @After + public void teardown() { + DatabaseTestUtils.cleanup(storage); + storage.close(); + taskExecutor.shutdown(); + } + + @Test + public void executorWorkflow() { + taskExecutor.start(); + } + + @Test + public 
void executorCannotBeStartedTwice() { + taskExecutor.start(); + assertThrows(IllegalStateException.class, () -> taskExecutor.start()); + } + + @Test + public void executorShouldWork() throws SQLException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + var op = createOperation(); + var task = taskExecutor.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + taskExecutor.start(); + op = waitForOperation(op.id()); + assertNull(op.error()); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FINISHED, task.status()); + assertTrue(task.createdAt().isBefore(task.updatedAt())); + assertNotNull(task.leaseTill()); + assertTrue(task.createdAt().isBefore(task.leaseTill())); + assertEquals(WORKER_ID, task.workerId()); + } + + @Test + public void executorCanFail() throws SQLException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, false)); + var op = createOperation(); + var task = taskExecutor.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + taskExecutor.start(); + op = waitForOperation(op.id()); + assertNotNull(op.error()); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FAILED, task.status()); + assertTrue(task.createdAt().isBefore(task.updatedAt())); + assertNotNull(task.leaseTill()); + assertTrue(task.createdAt().isBefore(task.leaseTill())); + assertEquals(WORKER_ID, task.workerId()); + } + + private Operation waitForOperation(String opId) throws SQLException { + while (true) { + var op = operationDao.get(opId, null); + if (op == null) { + fail("Operation " + opId + " cannot be null"); + } + if (op.done()) { + return op; + } + LockSupport.parkNanos(Duration.ofSeconds(1).toNanos()); + } + } + + private Operation createOperation() throws SQLException { + var operation = 
Operation.create("foo", "op", Duration.ofDays(1), null, null); + operationDao.create(operation, null); + return operation; + } + + private TypedOperationTaskResolver resolver(String type, + Supplier action, + boolean completeOperation) { + return new TypedOperationTaskResolver() { + @Override + public String type() { + return type; + } + + @Override + public Result resolve(OperationTask task, @Nullable TransactionHandle tx) { + return Result.success(new TestAction(action, completeOperation, task, taskDao, + SCHEDULER_DELAY, task.operationId(), "Test action", storage, operationDao, + operationsExecutor, taskExecutor)); + } + }; + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java new file mode 100644 index 0000000000..b1a259d722 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java @@ -0,0 +1,38 @@ +package ai.lzy.longrunning.task; + +import io.prometheus.client.Gauge; + +public class StubMetricsProvider implements TaskMetricsProvider { + + @Override + public Gauge schedulerErrors() { + return Gauge.build() + .name("scheduler_errors") + .help("help") + .create(); + } + + @Override + public Gauge schedulerResolveErrors(OperationTaskResolver.Status status) { + return Gauge.build() + .name("scheduler_resolve_errors") + .help("help") + .create(); + } + + @Override + public Gauge queueSize() { + return Gauge.build() + .name("queue_size") + .help("help") + .create(); + } + + @Override + public Gauge runningTasks() { + return Gauge.build() + .name("running_tasks") + .help("help") + .create(); + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java new file mode 100644 index 0000000000..2cd4419f96 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java @@ -0,0 
+1,51 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDao; +import ai.lzy.longrunning.task.dao.OperationTaskDao; +import ai.lzy.model.db.Storage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import io.grpc.Status; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.List; +import java.util.function.Supplier; + +public class TestAction extends OpTaskAwareAction { + + private final Supplier action; + private final boolean completeOperation; + + public TestAction(Supplier action, boolean completeOperation, OperationTask operationTask, + OperationTaskDao operationTaskDao, Duration leaseDuration, String opId, String desc, + Storage storage, OperationDao operationsDao, + OperationsExecutor executor, + OperationTaskExecutor operationTaskExecutor) + { + super(operationTask, operationTaskDao, leaseDuration, opId, desc, storage, operationsDao, executor, + operationTaskExecutor); + this.action = action; + this.completeOperation = completeOperation; + } + + @Override + protected List> steps() { + return List.of(this::doSmth); + } + + private StepResult doSmth() { + var stepResult = action.get(); + try { + if (completeOperation) { + completeOperation(null, Any.pack(Empty.getDefaultInstance()), null); + } else { + failOperation(Status.INTERNAL, null); + } + } catch (SQLException e) { + log().error("{} Error while completing operation", logPrefix()); + } + return stepResult; + } +} diff --git a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql index fe851476b4..925cbf506e 100644 --- a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql +++ b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql @@ -1,3 +1,28 @@ +CREATE TABLE operation +( + id TEXT NOT NULL PRIMARY KEY, + created_by TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + modified_at TIMESTAMP NOT 
NULL, + description TEXT NOT NULL, + deadline TIMESTAMP NULL, + done BOOLEAN NOT NULL, + + meta BYTEA NULL, + response BYTEA NULL, + error BYTEA NULL, + + idempotency_key TEXT NULL, + request_hash TEXT NULL, + + CHECK (((idempotency_key IS NOT NULL) AND (request_hash IS NOT NULL)) OR + ((idempotency_key IS NULL) AND (request_hash IS NULL))) +); + +CREATE UNIQUE INDEX idempotency_key_to_operation_index ON operation (idempotency_key); +CREATE UNIQUE INDEX failed_operations_index ON operation (id) WHERE done = TRUE AND error IS NOT NULL; +CREATE UNIQUE INDEX completed_operations_index ON operation (id) WHERE done = TRUE; + CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); @@ -14,5 +39,6 @@ CREATE TABLE IF NOT EXISTS operation_task( operation_id TEXT, worker_id TEXT, lease_till TIMESTAMP, - PRIMARY KEY (id) + PRIMARY KEY (id), + FOREIGN KEY (operation_id) REFERENCES operation(id) ); \ No newline at end of file From d55a53e1ef4b5cf33eb94a2df794a6c0e368b2ce Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 14:09:19 +0300 Subject: [PATCH 7/9] renaming --- .../longrunning/task/OpTaskAwareAction.java | 8 +-- ...cutor.java => OperationTaskScheduler.java} | 12 ++-- .../task/dao/OperationTaskDao.java | 2 + .../task/dao/OperationTaskDaoImpl.java | 18 ++++++ ...t.java => OperationTaskSchedulerTest.java} | 59 +++++++++++++------ .../ai/lzy/longrunning/task/TestAction.java | 4 +- 6 files changed, 74 insertions(+), 29 deletions(-) rename lzy/long-running/src/main/java/ai/lzy/longrunning/task/{OperationTaskExecutor.java => OperationTaskScheduler.java} (94%) rename lzy/long-running/src/test/java/ai/lzy/longrunning/task/{OperationTaskExecutorTest.java => OperationTaskSchedulerTest.java} (71%) diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java 
b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java index fc9a30bbea..d202521fd2 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java @@ -12,20 +12,20 @@ import static ai.lzy.model.db.DbHelper.withRetries; public abstract class OpTaskAwareAction extends OperationRunnerBase { - private final OperationTaskExecutor operationTaskExecutor; + private final OperationTaskScheduler operationTaskScheduler; private final OperationTaskDao operationTaskDao; private final Duration leaseDuration; private OperationTask operationTask; public OpTaskAwareAction(OperationTask operationTask, OperationTaskDao operationTaskDao, Duration leaseDuration, String opId, String desc, Storage storage, OperationDao operationsDao, - OperationsExecutor executor, OperationTaskExecutor operationTaskExecutor) + OperationsExecutor executor, OperationTaskScheduler operationTaskScheduler) { super(opId, desc, storage, operationsDao, executor); this.operationTask = operationTask; this.operationTaskDao = operationTaskDao; this.leaseDuration = leaseDuration; - this.operationTaskExecutor = operationTaskExecutor; + this.operationTaskScheduler = operationTaskScheduler; } @Override @@ -72,6 +72,6 @@ protected void notifyFinished() { } catch (Exception e) { log().error("{} Couldn't finish operation task {}", logPrefix(), task().id()); } - operationTaskExecutor.releaseTask(task()); + operationTaskScheduler.releaseTask(task()); } } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java similarity index 94% rename from lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java rename to lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java index 4e2744c2d1..021dcc2cff 100644 --- 
a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskExecutor.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java @@ -19,9 +19,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -public class OperationTaskExecutor { +public class OperationTaskScheduler { - private static final Logger LOG = LogManager.getLogger(OperationTaskExecutor.class); + private static final Logger LOG = LogManager.getLogger(OperationTaskScheduler.class); private final OperationTaskDao opTaskDao; private final OperationsExecutor operationsExecutor; @@ -39,10 +39,10 @@ public class OperationTaskExecutor { private final AtomicBoolean started = new AtomicBoolean(false); private volatile boolean disabled = false; - public OperationTaskExecutor(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor, - OperationTaskResolver resolver, Duration initialDelay, Duration executionDelay, - Storage storage, TaskMetricsProvider metricsProvider, String instanceId, - Duration leaseDuration, int batchSize, int maxRunningTasks) + public OperationTaskScheduler(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor, + OperationTaskResolver resolver, Duration initialDelay, Duration executionDelay, + Storage storage, TaskMetricsProvider metricsProvider, String instanceId, + Duration leaseDuration, int batchSize, int maxRunningTasks) { this.opTaskDao = opTaskDao; this.operationsExecutor = operationsExecutor; diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java index 4ad8c09cbc..56dd5a76b7 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java @@ -33,4 +33,6 @@ List recaptureOldTasks(String ownerId, Duration leaseDuration, @N @Nullable 
OperationTask tryLockTask(Long taskId, String entityId, String ownerId, Duration leaseDuration, @Nullable TransactionHandle tx) throws SQLException; + + List getAll(@Nullable TransactionHandle tx) throws SQLException; } diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java index c727a916ce..93cb23912b 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java @@ -27,6 +27,9 @@ INSERT INTO operation_task (name, entity_id, type, status, created_at, updated_a VALUES (?, ?, cast(? as task_type), cast(? as task_status), now(), now(), cast(? as jsonb), ?) RETURNING %s; """.formatted(FIELDS); + public static final String GET_ALL_QUERY = """ + SELECT %s FROM operation_task + """.formatted(FIELDS); //in first nested request we gather all tasks that either locked or free. 
//in second nested request we filter result of previous request to get only free tasks @@ -243,6 +246,21 @@ public OperationTask tryLockTask(Long taskId, String entityId, String ownerId, D }); } + @Override + public List getAll(@Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(GET_ALL_QUERY)) { + try (var rs = ps.executeQuery()) { + var result = new ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + private static String prepareUpdateQuery(OperationTask.Update update) { var sb = new StringBuilder("UPDATE operation_task SET "); if (update.status() != null) { diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java similarity index 71% rename from lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java rename to lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java index 14bee02fbe..d2ad4cb048 100644 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskExecutorTest.java +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java @@ -22,12 +22,14 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.locks.LockSupport; import java.util.function.Supplier; +import java.util.stream.Collectors; import static org.junit.Assert.*; -public class OperationTaskExecutorTest { +public class OperationTaskSchedulerTest { public static final String MOUNT_TASK_TYPE = "MOUNT"; public static final Duration SCHEDULER_DELAY = Duration.ofMinutes(5); @@ -44,12 +46,11 @@ public class OperationTaskExecutorTest { private StorageImpl storage; private OperationTaskDaoImpl taskDao; - private 
OperationTaskExecutor taskExecutor; + private OperationTaskScheduler taskScheduler; private OperationDaoImpl operationDao; private OperationsExecutor operationsExecutor; private DispatchingOperationTaskResolver taskResolver; - @Before public void setup() { storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), @@ -58,7 +59,7 @@ public void setup() { operationDao = new OperationDaoImpl(storage); operationsExecutor = new OperationsExecutor(5, 10, () -> {}, e -> false); taskResolver = new DispatchingOperationTaskResolver(List.of()); - taskExecutor = new OperationTaskExecutor(taskDao, operationsExecutor, taskResolver, Duration.ZERO, + taskScheduler = new OperationTaskScheduler(taskDao, operationsExecutor, taskResolver, Duration.ZERO, SCHEDULER_DELAY, storage, new StubMetricsProvider(), WORKER_ID , LEASE_DURATION, BATCH_SIZE, MAX_RUNNING_TASKS); @@ -68,27 +69,27 @@ SCHEDULER_DELAY, storage, new StubMetricsProvider(), WORKER_ID , LEASE_DURATION, public void teardown() { DatabaseTestUtils.cleanup(storage); storage.close(); - taskExecutor.shutdown(); + taskScheduler.shutdown(); } @Test - public void executorWorkflow() { - taskExecutor.start(); + public void schedulerWorkflow() { + taskScheduler.start(); } @Test - public void executorCannotBeStartedTwice() { - taskExecutor.start(); - assertThrows(IllegalStateException.class, () -> taskExecutor.start()); + public void schedulerCannotBeStartedTwice() { + taskScheduler.start(); + assertThrows(IllegalStateException.class, () -> taskScheduler.start()); } @Test - public void executorShouldWork() throws SQLException { + public void schedulerShouldWork() throws SQLException { taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); var op = createOperation(); - var task = taskExecutor.saveTask(OperationTask.createPending("Test", "foo", + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", MOUNT_TASK_TYPE, Map.of(), op.id()), 
null); - taskExecutor.start(); + taskScheduler.start(); op = waitForOperation(op.id()); assertNull(op.error()); task = taskDao.get(task.id(), null); @@ -101,12 +102,12 @@ public void executorShouldWork() throws SQLException { } @Test - public void executorCanFail() throws SQLException { + public void schedulerCanFail() throws SQLException { taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, false)); var op = createOperation(); - var task = taskExecutor.saveTask(OperationTask.createPending("Test", "foo", + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", MOUNT_TASK_TYPE, Map.of(), op.id()), null); - taskExecutor.start(); + taskScheduler.start(); op = waitForOperation(op.id()); assertNotNull(op.error()); task = taskDao.get(task.id(), null); @@ -118,6 +119,30 @@ public void executorCanFail() throws SQLException { assertEquals(WORKER_ID, task.workerId()); } + @Test + public void schedulerWillLoadOnlyOneBatch() throws SQLException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> { + latch.countDown(); + return OperationRunnerBase.StepResult.FINISH; + }, true)); + for (int i = 0; i < MAX_RUNNING_TASKS; i++) { + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo" + i, + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + } + taskScheduler.start(); + latch.await(); + var tasks = taskDao.getAll(null); + var tasksByPendingStatus = tasks.stream() + .collect(Collectors.partitioningBy(x -> x.status() == OperationTask.Status.PENDING)); + var pendingTasks = tasksByPendingStatus.get(true); + var runningTasks = tasksByPendingStatus.get(false); + assertEquals(MAX_RUNNING_TASKS - BATCH_SIZE, pendingTasks.size()); + assertEquals(BATCH_SIZE, runningTasks.size()); + + } + private Operation waitForOperation(String opId) throws SQLException { while (true) { var op = operationDao.get(opId, 
null); @@ -150,7 +175,7 @@ public String type() { public Result resolve(OperationTask task, @Nullable TransactionHandle tx) { return Result.success(new TestAction(action, completeOperation, task, taskDao, SCHEDULER_DELAY, task.operationId(), "Test action", storage, operationDao, - operationsExecutor, taskExecutor)); + operationsExecutor, taskScheduler)); } }; } diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java index 2cd4419f96..38377744ae 100644 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java @@ -22,10 +22,10 @@ public TestAction(Supplier action, boolean completeOperation, Operat OperationTaskDao operationTaskDao, Duration leaseDuration, String opId, String desc, Storage storage, OperationDao operationsDao, OperationsExecutor executor, - OperationTaskExecutor operationTaskExecutor) + OperationTaskScheduler operationTaskScheduler) { super(operationTask, operationTaskDao, leaseDuration, opId, desc, storage, operationsDao, executor, - operationTaskExecutor); + operationTaskScheduler); this.action = action; this.completeOperation = completeOperation; } From 5d0a2b28d5491265e793c59713b502012d76bc5b Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 14:51:24 +0300 Subject: [PATCH 8/9] few more tests --- .../task/OperationTaskScheduler.java | 43 +++++----- .../task/OperationTaskSchedulerTest.java | 81 ++++++++++++++++++- 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java index 021dcc2cff..83124bd6d1 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java +++ 
b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java @@ -52,7 +52,8 @@ public OperationTaskScheduler(OperationTaskDao opTaskDao, OperationsExecutor ope this.storage = storage; this.leaseDuration = leaseDuration; this.batchSize = batchSize; - this.scheduler = Executors.newSingleThreadScheduledExecutor(); //it's important to have only one thread + //it's important to have only one thread to execute all command subsequently + this.scheduler = Executors.newSingleThreadScheduledExecutor(); this.metricsProvider = metricsProvider; this.instanceId = instanceId; this.runningTaskQuota.set(maxRunningTasks); @@ -147,27 +148,29 @@ private OpTaskAwareAction resolveTask(OperationTask operationTask, TransactionHa } private void restoreTasks() { - try { - var actionsToRun = new ArrayList(); - //allow over-quoting to complete unfinished tasks first - DbHelper.withRetries(LOG, () -> { - try (var tx = TransactionHandle.create(storage)) { - var tasks = opTaskDao.recaptureOldTasks(instanceId, leaseDuration, tx); - for (OperationTask operationTask : tasks) { - var taskAwareAction = resolveTask(operationTask, tx); - if (taskAwareAction != null) { - actionsToRun.add(taskAwareAction); + scheduler.schedule(() -> { + try { + var actionsToRun = new ArrayList(); + //allow over-quoting to complete unfinished tasks first + DbHelper.withRetries(LOG, () -> { + try (var tx = TransactionHandle.create(storage)) { + var tasks = opTaskDao.recaptureOldTasks(instanceId, leaseDuration, tx); + for (OperationTask operationTask : tasks) { + var taskAwareAction = resolveTask(operationTask, tx); + if (taskAwareAction != null) { + actionsToRun.add(taskAwareAction); + } } + tx.commit(); } - tx.commit(); - } - }); - acquireTasks(actionsToRun.size()); - actionsToRun.forEach(operationsExecutor::startNew); - } catch (Exception e) { - LOG.error("Got exception while restoring tasks", e); - metricsProvider.schedulerErrors().inc(); - } + }); + acquireTasks(actionsToRun.size()); + 
actionsToRun.forEach(operationsExecutor::startNew); + } catch (Exception e) { + LOG.error("Got exception while restoring tasks", e); + metricsProvider.schedulerErrors().inc(); + } + }, 0, TimeUnit.MILLISECONDS); } private OperationTask setStatus(OperationTask operationTask, OperationTask.Status status, TransactionHandle tx) diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java index d2ad4cb048..18957f3b64 100644 --- a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.locks.LockSupport; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -90,7 +91,7 @@ public void schedulerShouldWork() throws SQLException { var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", MOUNT_TASK_TYPE, Map.of(), op.id()), null); taskScheduler.start(); - op = waitForOperation(op.id()); + op = waitForOperation(op); assertNull(op.error()); task = taskDao.get(task.id(), null); assertNotNull(task); @@ -108,7 +109,7 @@ public void schedulerCanFail() throws SQLException { var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", MOUNT_TASK_TYPE, Map.of(), op.id()), null); taskScheduler.start(); - op = waitForOperation(op.id()); + op = waitForOperation(op); assertNotNull(op.error()); task = taskDao.get(task.id(), null); assertNotNull(task); @@ -140,7 +141,83 @@ public void schedulerWillLoadOnlyOneBatch() throws SQLException, InterruptedExce var runningTasks = tasksByPendingStatus.get(false); assertEquals(MAX_RUNNING_TASKS - BATCH_SIZE, pendingTasks.size()); assertEquals(BATCH_SIZE, 
runningTasks.size()); + } + + @Test + public void schedulerCanScheduleImmediateTask() throws SQLException, ExecutionException, InterruptedException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + taskScheduler.start(); + + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + var scheduledFuture = taskScheduler.startImmediately(task); + var scheduled = scheduledFuture.get(); + assertTrue(scheduled); + + waitForOperation(op); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FINISHED, task.status()); + } + + @Test + public void schedulerCannotScheduleTwoImmediateTaskWithSameEntityId() + throws SQLException, ExecutionException, InterruptedException + { + CountDownLatch latch = new CountDownLatch(1); + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> { + try { + latch.await(); + } catch (InterruptedException e) {} + return OperationRunnerBase.StepResult.FINISH; + }, true)); + taskScheduler.start(); + + var op1 = createOperation(); + var task1 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op1.id()), null); + var op2 = createOperation(); + var task2 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op2.id()), null); + var task1Future = taskScheduler.startImmediately(task1); + var task2Future = taskScheduler.startImmediately(task2); + assertTrue(task1Future.get()); + assertFalse(task2Future.get()); + + task1 = taskDao.get(task1.id(), null); + task2 = taskDao.get(task2.id(), null); + assertNotNull(task1); + assertNotNull(task2); + assertEquals(OperationTask.Status.RUNNING, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + assertEquals(OperationTask.Status.PENDING, task2.status()); + assertNull(task2.workerId()); + latch.countDown(); + } + + @Test 
+ public void schedulerShouldRestartOldTasksOnStart() throws SQLException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + var op1 = createOperation(); + var task1 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op1.id()), null); + task1 = taskDao.tryLockTask(task1.id(), task1.entityId(), WORKER_ID, LEASE_DURATION, null); + assertNotNull(task1); + assertEquals(OperationTask.Status.RUNNING, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + + taskScheduler.start(); + waitForOperation(op1); + + task1 = taskDao.get(task1.id(), null); + assertNotNull(task1); + assertEquals(OperationTask.Status.FINISHED, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + } + private Operation waitForOperation(Operation op) throws SQLException { + return waitForOperation(op.id()); } private Operation waitForOperation(String opId) throws SQLException { From 4c9049c3fe4549e1ce10c8b08d5e9f84a757ef61 Mon Sep 17 00:00:00 2001 From: Hoxo <29162640+Hoxo@users.noreply.github.com> Date: Mon, 31 Jul 2023 14:55:01 +0300 Subject: [PATCH 9/9] example --- .../alloc/MountDynamicDiskAction.java | 18 ++-- .../task/MountDynamicDiskResolver.java | 88 +++++++++++++++++++ 2 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java index 97378fd0bd..e4a7b3d16e 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java @@ -7,7 +7,10 @@ import ai.lzy.allocator.volume.VolumeManager; import ai.lzy.logs.LogContextKey; import ai.lzy.longrunning.Operation; -import ai.lzy.longrunning.OperationRunnerBase; 
+import ai.lzy.longrunning.task.OpTaskAwareAction; +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.longrunning.task.OperationTaskScheduler; +import ai.lzy.longrunning.task.dao.OperationTaskDao; import ai.lzy.model.db.TransactionHandle; import ai.lzy.v1.VmAllocatorApi; import com.google.protobuf.Any; @@ -17,6 +20,7 @@ import jakarta.annotation.Nullable; import java.sql.SQLException; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -25,7 +29,7 @@ import static ai.lzy.model.db.DbHelper.withRetries; -public final class MountDynamicDiskAction extends OperationRunnerBase { +public final class MountDynamicDiskAction extends OpTaskAwareAction { private final AllocationContext allocationContext; private final VolumeManager volumeManager; private final MountHolderManager mountHolderManager; @@ -42,9 +46,13 @@ public final class MountDynamicDiskAction extends OperationRunnerBase { private Long nextId; private boolean mountPodsDeleted; - public MountDynamicDiskAction(Vm vm, DynamicMount dynamicMount, AllocationContext allocationContext) { - super(dynamicMount.mountOperationId(), String.format("Mount %s to VM %s", dynamicMount.mountName(), vm.vmId()), - allocationContext.storage(), allocationContext.operationsDao(), allocationContext.executor()); + public MountDynamicDiskAction(Vm vm, DynamicMount dynamicMount, AllocationContext allocationContext, + OperationTask operationTask, OperationTaskDao operationTaskDao, + Duration leaseDuration, OperationTaskScheduler operationTaskScheduler) + { + super(operationTask, operationTaskDao, leaseDuration, dynamicMount.mountOperationId(), + String.format("Mount %s to VM %s", dynamicMount.mountName(), vm.vmId()), allocationContext.storage(), + allocationContext.operationsDao(), allocationContext.executor(), operationTaskScheduler); this.dynamicMount = dynamicMount; this.vm = vm; this.allocationContext = allocationContext; diff --git 
a/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java b/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java new file mode 100644 index 0000000000..1aa0c02b46 --- /dev/null +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java @@ -0,0 +1,88 @@ +package ai.lzy.allocator.task; + +import ai.lzy.allocator.alloc.AllocationContext; +import ai.lzy.allocator.alloc.MountDynamicDiskAction; +import ai.lzy.allocator.alloc.dao.DynamicMountDao; +import ai.lzy.allocator.alloc.dao.VmDao; +import ai.lzy.allocator.model.DynamicMount; +import ai.lzy.allocator.model.Vm; +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.longrunning.task.OperationTaskScheduler; +import ai.lzy.longrunning.task.ResolverUtils; +import ai.lzy.longrunning.task.TypedOperationTaskResolver; +import ai.lzy.longrunning.task.dao.OperationTaskDao; +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; +import jakarta.inject.Singleton; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.sql.SQLException; +import java.time.Duration; + +@Singleton +public class MountDynamicDiskResolver implements TypedOperationTaskResolver { + private static final Logger LOG = LogManager.getLogger(MountDynamicDiskResolver.class); + + private static final String TYPE = "MOUNT"; + public static final String VM_ID_FIELD = "vm_id"; + public static final String DYNAMIC_MOUNT_ID_FIELD = "dynamic_mount_id"; + + private final VmDao vmDao; + private final DynamicMountDao dynamicMountDao; + private final AllocationContext allocationContext; + private final OperationTaskDao operationTaskDao; + private final OperationTaskScheduler taskScheduler; //todo circular dependency + private final Duration leaseDuration; + + public MountDynamicDiskResolver(VmDao vmDao, DynamicMountDao dynamicMountDao, AllocationContext allocationContext, + OperationTaskDao operationTaskDao, 
OperationTaskScheduler taskScheduler, + Duration leaseDuration) + //todo mark duration with a qualifier + { + this.vmDao = vmDao; + this.dynamicMountDao = dynamicMountDao; + this.allocationContext = allocationContext; + this.operationTaskDao = operationTaskDao; + this.taskScheduler = taskScheduler; + this.leaseDuration = leaseDuration; + } + + @Override + public Result resolve(OperationTask opTask, @Nullable TransactionHandle tx) throws SQLException { + var vmId = ResolverUtils.readString(opTask.metadata(), VM_ID_FIELD); + if (vmId == null) { + LOG.error("{} field is not present in task {} metadata", VM_ID_FIELD, opTask.id()); + return Result.BAD_STATE; + } + var dynamicMountId = ResolverUtils.readString(opTask.metadata(), DYNAMIC_MOUNT_ID_FIELD); + if (dynamicMountId == null) { + LOG.error("{} field is not present in task {} metadata", DYNAMIC_MOUNT_ID_FIELD, opTask.id()); + return Result.BAD_STATE; + } + var vm = vmDao.get(vmId, tx); + if (vm == null) { + LOG.error("VM {} is not present for task", vmId); + return Result.STALE; + } else if (vm.status() != Vm.Status.RUNNING) { + LOG.error("VM {} is in wrong status: {}", vmId, vm.status()); + return Result.STALE; + } + var dynamicMount = dynamicMountDao.get(dynamicMountId, false, tx); + if (dynamicMount == null) { + LOG.error("Dynamic mount {} is not present for task", dynamicMountId); + return Result.STALE; + } else if (dynamicMount.state() != DynamicMount.State.PENDING) { + LOG.error("Dynamic mount {} is in wrong status: {}", dynamicMount.id(), dynamicMount.state()); + return Result.STALE; + } + return Result.success(new MountDynamicDiskAction(vm, dynamicMount, allocationContext, opTask, operationTaskDao, + leaseDuration, taskScheduler)); + } + + @Override + public String type() { + return TYPE; + } + +}