diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java index 1e7ed69067..e77d8d09d4 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/DeleteSessionAction.java @@ -46,14 +46,6 @@ protected boolean isInjectedError(Error e) { return e instanceof InjectedFailures.TerminateException; } - @Override - protected void notifyExpired() { - } - - @Override - protected void notifyFinished() { - } - @Override protected void onExpired(TransactionHandle tx) { throw new RuntimeException("Unexpected, sessionId: '%s', op: '%s'".formatted(sessionId, id())); diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java index 97378fd0bd..e4a7b3d16e 100644 --- a/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java +++ b/lzy/allocator/src/main/java/ai/lzy/allocator/alloc/MountDynamicDiskAction.java @@ -7,7 +7,10 @@ import ai.lzy.allocator.volume.VolumeManager; import ai.lzy.logs.LogContextKey; import ai.lzy.longrunning.Operation; -import ai.lzy.longrunning.OperationRunnerBase; +import ai.lzy.longrunning.task.OpTaskAwareAction; +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.longrunning.task.OperationTaskScheduler; +import ai.lzy.longrunning.task.dao.OperationTaskDao; import ai.lzy.model.db.TransactionHandle; import ai.lzy.v1.VmAllocatorApi; import com.google.protobuf.Any; @@ -17,6 +20,7 @@ import jakarta.annotation.Nullable; import java.sql.SQLException; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -25,7 +29,7 @@ import static ai.lzy.model.db.DbHelper.withRetries; -public final class MountDynamicDiskAction extends OperationRunnerBase { +public final class MountDynamicDiskAction 
extends OpTaskAwareAction { private final AllocationContext allocationContext; private final VolumeManager volumeManager; private final MountHolderManager mountHolderManager; @@ -42,9 +46,13 @@ public final class MountDynamicDiskAction extends OperationRunnerBase { private Long nextId; private boolean mountPodsDeleted; - public MountDynamicDiskAction(Vm vm, DynamicMount dynamicMount, AllocationContext allocationContext) { - super(dynamicMount.mountOperationId(), String.format("Mount %s to VM %s", dynamicMount.mountName(), vm.vmId()), - allocationContext.storage(), allocationContext.operationsDao(), allocationContext.executor()); + public MountDynamicDiskAction(Vm vm, DynamicMount dynamicMount, AllocationContext allocationContext, + OperationTask operationTask, OperationTaskDao operationTaskDao, + Duration leaseDuration, OperationTaskScheduler operationTaskScheduler) + { + super(operationTask, operationTaskDao, leaseDuration, dynamicMount.mountOperationId(), + String.format("Mount %s to VM %s", dynamicMount.mountName(), vm.vmId()), allocationContext.storage(), + allocationContext.operationsDao(), allocationContext.executor(), operationTaskScheduler); this.dynamicMount = dynamicMount; this.vm = vm; this.allocationContext = allocationContext;
diff --git a/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java b/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java
new file mode 100644
index 0000000000..1aa0c02b46
--- /dev/null
+++ b/lzy/allocator/src/main/java/ai/lzy/allocator/task/MountDynamicDiskResolver.java
@@ -0,0 +1,111 @@
+package ai.lzy.allocator.task;
+
+import ai.lzy.allocator.alloc.AllocationContext;
+import ai.lzy.allocator.alloc.MountDynamicDiskAction;
+import ai.lzy.allocator.alloc.dao.DynamicMountDao;
+import ai.lzy.allocator.alloc.dao.VmDao;
+import ai.lzy.allocator.model.DynamicMount;
+import ai.lzy.allocator.model.Vm;
+import ai.lzy.longrunning.task.OperationTask;
+import ai.lzy.longrunning.task.OperationTaskScheduler;
+import ai.lzy.longrunning.task.ResolverUtils;
+import ai.lzy.longrunning.task.TypedOperationTaskResolver;
+import ai.lzy.longrunning.task.dao.OperationTaskDao;
+import ai.lzy.model.db.TransactionHandle;
+import jakarta.annotation.Nullable;
+import jakarta.inject.Singleton;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.sql.SQLException;
+import java.time.Duration;
+
+/**
+ * Resolver for operation tasks of type {@code MOUNT}: turns a stored task into a
+ * {@link MountDynamicDiskAction} for the VM and dynamic mount named in the task metadata.
+ */
+@Singleton
+public class MountDynamicDiskResolver implements TypedOperationTaskResolver {
+    private static final Logger LOG = LogManager.getLogger(MountDynamicDiskResolver.class);
+
+    private static final String TYPE = "MOUNT";
+    public static final String VM_ID_FIELD = "vm_id";
+    public static final String DYNAMIC_MOUNT_ID_FIELD = "dynamic_mount_id";
+
+    private final VmDao vmDao;
+    private final DynamicMountDao dynamicMountDao;
+    private final AllocationContext allocationContext;
+    private final OperationTaskDao operationTaskDao;
+    private final OperationTaskScheduler taskScheduler; //todo circular dependency
+    private final Duration leaseDuration;
+
+    public MountDynamicDiskResolver(VmDao vmDao, DynamicMountDao dynamicMountDao, AllocationContext allocationContext,
+                                    OperationTaskDao operationTaskDao, OperationTaskScheduler taskScheduler,
+                                    Duration leaseDuration)
+    //todo mark duration with a qualifier
+    {
+        this.vmDao = vmDao;
+        this.dynamicMountDao = dynamicMountDao;
+        this.allocationContext = allocationContext;
+        this.operationTaskDao = operationTaskDao;
+        this.taskScheduler = taskScheduler;
+        this.leaseDuration = leaseDuration;
+    }
+
+    /**
+     * Validates the task metadata and the current VM / dynamic mount state, then builds the mount action.
+     *
+     * @param opTask task whose metadata carries {@link #VM_ID_FIELD} and {@link #DYNAMIC_MOUNT_ID_FIELD}
+     * @param tx     transaction passed to the VM and dynamic mount lookups, may be {@code null}
+     * @return {@code BAD_STATE} when a metadata field is missing, {@code STALE} when the VM or the mount is
+     *     absent or not in the expected state, otherwise a successful result holding the action
+     */
+    @Override
+    public Result resolve(OperationTask opTask, @Nullable TransactionHandle tx) throws SQLException {
+        final var vmId = ResolverUtils.readString(opTask.metadata(), VM_ID_FIELD);
+        if (vmId == null) {
+            return missingField(VM_ID_FIELD, opTask);
+        }
+
+        final var mountId = ResolverUtils.readString(opTask.metadata(), DYNAMIC_MOUNT_ID_FIELD);
+        if (mountId == null) {
+            return missingField(DYNAMIC_MOUNT_ID_FIELD, opTask);
+        }
+
+        final var vm = vmDao.get(vmId, tx);
+        if (vm == null) {
+            LOG.error("VM {} is not present for task", vmId);
+            return Result.STALE;
+        }
+        if (vm.status() != Vm.Status.RUNNING) {
+            LOG.error("VM {} is in wrong status: {}", vmId, vm.status());
+            return Result.STALE;
+        }
+
+        final var mount = dynamicMountDao.get(mountId, false, tx);
+        if (mount == null) {
+            LOG.error("Dynamic mount {} is not present for task", mountId);
+            return Result.STALE;
+        }
+        if (mount.state() != DynamicMount.State.PENDING) {
+            LOG.error("Dynamic mount {} is in wrong status: {}", mount.id(), mount.state());
+            return Result.STALE;
+        }
+
+        final var action = new MountDynamicDiskAction(vm, mount, allocationContext, opTask, operationTaskDao,
+            leaseDuration, taskScheduler);
+        return Result.success(action);
+    }
+
+    // Logs the absent metadata field and reports the task as malformed.
+    private static Result missingField(String field, OperationTask opTask) {
+        LOG.error("{} field is not present in task {} metadata", field, opTask.id());
+        return Result.BAD_STATE;
+    }
+
+    @Override
+    public String type() {
+        return TYPE;
+    }
+
+}
diff --git a/lzy/allocator/src/main/resources/db/allocator/migrations/V10__operation_task.sql b/lzy/allocator/src/main/resources/db/allocator/migrations/V10__operation_task.sql new file mode 100644 index 0000000000..8149674abc --- /dev/null +++ b/lzy/allocator/src/main/resources/db/allocator/migrations/V10__operation_task.sql @@ -0,0 +1,21 @@ +CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED', 'STALE'); + +CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); + +CREATE TABLE IF NOT EXISTS operation_task( + id BIGSERIAL NOT NULL, + name TEXT NOT NULL, + entity_id TEXT NOT NULL, + type task_type NOT NULL, + status task_status NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + metadata JSONB NOT NULL, + operation_id TEXT, + worker_id TEXT, + lease_till TIMESTAMP, + PRIMARY KEY (id), + FOREIGN KEY (operation_id) REFERENCES operation(id) +); + 
+CREATE INDEX IF NOT EXISTS task_status_entity_id_idx ON operation_task(status, entity_id, id); diff --git a/lzy/long-running/pom.xml b/lzy/long-running/pom.xml index 7930f9992d..634dbb4d59 100644 --- a/lzy/long-running/pom.xml +++ b/lzy/long-running/pom.xml @@ -45,6 +45,16 @@ junit test + + io.zonky.test + embedded-postgres + test + + + org.mockito + mockito-core + test + diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java index 434392b044..61ef5bcda7 100644 --- a/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/OperationRunnerBase.java @@ -32,6 +32,7 @@ public abstract class OperationRunnerBase extends ContextAwareTask { private final OperationDao operationsDao; private final OperationsExecutor executor; private Operation op; + private volatile boolean failed = false; protected OperationRunnerBase(String id, String descr, Storage storage, OperationDao operationsDao, OperationsExecutor executor) @@ -55,6 +56,7 @@ protected final void execute() { } for (var step : steps()) { + beforeStep(); final var stepResult = step.get(); switch (stepResult.code()) { case ALREADY_DONE -> { } @@ -83,6 +85,7 @@ protected final void execute() { } } } catch (Throwable e) { + setFailed(); notifyFinished(); if (e instanceof Error err && isInjectedError(err)) { log.error("{} Terminated by InjectedFailure exception: {}", logPrefix, e.getMessage()); @@ -98,6 +101,18 @@ protected final void execute() { } } + protected void setFailed() { + failed = true; + } + + protected boolean isFailed() { + return failed; + } + + protected void beforeStep() { + + } + protected Map prepareLogContext() { var ctx = super.prepareLogContext(); ctx.put(LogContextKey.OPERATION_ID, id); @@ -276,6 +291,7 @@ protected void notifyFinished() { } protected final void failOperation(Status status, @Nullable TransactionHandle 
tx) throws SQLException { + setFailed(); operationsDao.fail(id, toProto(status), tx); }
diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java
new file mode 100644
index 0000000000..8656b72de2
--- /dev/null
+++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/DispatchingOperationTaskResolver.java
@@ -0,0 +1,69 @@
+package ai.lzy.longrunning.task;
+
+import ai.lzy.model.db.TransactionHandle;
+import com.google.common.annotations.VisibleForTesting;
+import jakarta.annotation.Nullable;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.sql.SQLException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Routes every {@link OperationTask} to the {@link TypedOperationTaskResolver} registered for the task's type.
+ */
+public class DispatchingOperationTaskResolver implements OperationTaskResolver {
+
+    private static final Logger LOG = LogManager.getLogger(DispatchingOperationTaskResolver.class);
+    private final Map<String, TypedOperationTaskResolver> resolvers;
+
+    /**
+     * @param resolvers resolvers to dispatch to; each task type may be served by one resolver only
+     * @throws IllegalStateException if two resolvers report the same type
+     */
+    public DispatchingOperationTaskResolver(List<TypedOperationTaskResolver> resolvers) {
+        this.resolvers = generateResolversMap(resolvers);
+    }
+
+    // Indexes resolvers by type in a single pass. The map is an explicit HashMap because addResolver() mutates
+    // it, and Collectors.toMap() makes no promise about the mutability of the map it returns.
+    private static Map<String, TypedOperationTaskResolver> generateResolversMap(
+        List<TypedOperationTaskResolver> resolvers)
+    {
+        var byType = new HashMap<String, TypedOperationTaskResolver>();
+        for (var resolver : resolvers) {
+            if (byType.putIfAbsent(resolver.type(), resolver) != null) {
+                throw new IllegalStateException("Duplicate resolver for type " + resolver.type());
+            }
+        }
+        return byType;
+    }
+
+    @VisibleForTesting
+    void addResolver(TypedOperationTaskResolver resolver) {
+        resolvers.put(resolver.type(), resolver);
+    }
+
+    /**
+     * Delegates to the resolver registered for {@code operationTask.type()}.
+     *
+     * @return {@code UNKNOWN_TASK} when no resolver is registered for the type, a {@code RESOLVE_ERROR}
+     *     result wrapping the exception when the delegate throws, otherwise the delegate's own result
+     */
+    @Override
+    public Result resolve(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException {
+        var resolver = resolvers.get(operationTask.type());
+        if (resolver == null) {
+            LOG.error("No resolver for task type {}. Task: {}", operationTask.type(), operationTask);
+            return Result.UNKNOWN_TASK;
+        }
+        try {
+            return resolver.resolve(operationTask, tx);
+        } catch (Exception e) {
+            LOG.error("Error while resolving task {}", operationTask.id(), e);
+            return Result.resolveError(e);
+        }
+    }
+}
diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java
new file mode 100644
index 0000000000..d202521fd2
--- /dev/null
+++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OpTaskAwareAction.java
@@ -0,0 +1,100 @@
+package ai.lzy.longrunning.task;
+
+import ai.lzy.longrunning.OperationRunnerBase;
+import ai.lzy.longrunning.OperationsExecutor;
+import ai.lzy.longrunning.dao.OperationDao;
+import ai.lzy.longrunning.task.dao.OperationTaskDao;
+import ai.lzy.model.db.Storage;
+
+import java.time.Duration;
+import java.util.Map;
+
+import static ai.lzy.model.db.DbHelper.withRetries;
+
+/**
+ * Operation runner bound to a persisted {@link OperationTask}: it prolongs the task lease before every step
+ * and, when the operation is over, stores the final task status and releases the task in the scheduler.
+ */
+public abstract class OpTaskAwareAction extends OperationRunnerBase {
+    private final OperationTaskScheduler operationTaskScheduler;
+    private final OperationTaskDao operationTaskDao;
+    private final Duration leaseDuration;
+    private OperationTask operationTask;
+
+    public OpTaskAwareAction(OperationTask operationTask, OperationTaskDao operationTaskDao, Duration leaseDuration,
+                             String opId, String desc, Storage storage, OperationDao operationsDao,
+                             OperationsExecutor executor, OperationTaskScheduler operationTaskScheduler)
+    {
+        super(opId, desc, storage, operationsDao, executor);
+        this.operationTask = operationTask;
+        this.operationTaskDao = operationTaskDao;
+        this.leaseDuration = leaseDuration;
+        this.operationTaskScheduler = operationTaskScheduler;
+    }
+
+    @Override
+    protected Map<String, String> prepareLogContext() {
+        var ctx = super.prepareLogContext();
+        ctx.put("task_id", String.valueOf(operationTask.id()));
+        ctx.put("task_type", operationTask.type());
+        ctx.put("task_name", operationTask.name());
+        ctx.put("task_entity_id", operationTask.entityId());
+        return ctx;
+    }
+
+    protected OperationTask task() {
+        return operationTask;
+    }
+
+    public void setTask(OperationTask operationTask) {
+        this.operationTask = operationTask;
+    }
+
+    /**
+     * Prolongs the task lease by {@code leaseDuration} before the next step runs.
+     *
+     * @throws IllegalStateException if the task row is no longer returned by the DAO
+     * @throws RuntimeException if the lease update fails
+     */
+    @Override
+    protected void beforeStep() {
+        super.beforeStep();
+        final long taskId = operationTask.id();
+        final OperationTask refreshed;
+        try {
+            refreshed = withRetries(log(), () -> operationTaskDao.updateLease(taskId, leaseDuration, null));
+        } catch (Exception e) {
+            log().error("{} Couldn't update lease on task {}", logPrefix(), taskId, e);
+            throw new RuntimeException(e);
+        }
+        if (refreshed == null) {
+            // updateLease() is @Nullable; storing null here would turn every later task().id() call,
+            // including the one in notifyFinished(), into a NullPointerException.
+            throw new IllegalStateException("Operation task " + taskId + " not found while updating lease");
+        }
+        operationTask = refreshed;
+    }
+
+    /**
+     * Stores the final task status ({@code FAILED} when {@link #isFailed()} is set, {@code FINISHED}
+     * otherwise) and releases the task in the scheduler. The release happens even if the status
+     * update fails.
+     */
+    @Override
+    protected void notifyFinished() {
+        final var finalStatus = isFailed() ? OperationTask.Status.FAILED : OperationTask.Status.FINISHED;
+        final var update = OperationTask.Update.builder().status(finalStatus).build();
+        final long taskId = operationTask.id();
+        try {
+            final var updated = withRetries(log(), () -> operationTaskDao.update(taskId, update, null));
+            if (updated != null) {
+                // update() is @Nullable: keep the last known task rather than overwrite it with null
+                operationTask = updated;
+            }
+        } catch (Exception e) {
+            log().error("{} Couldn't finish operation task {}", logPrefix(), taskId, e);
+        } finally {
+            operationTaskScheduler.releaseTask(operationTask);
+        }
+    }
+}
diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java new file mode 100644 index 0000000000..8e4ef3dbaf --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTask.java @@ -0,0 +1,77 @@ +package ai.lzy.longrunning.task; + +import jakarta.annotation.Nullable; + +import java.time.Instant; +import java.util.Map; + +public record OperationTask( + long id, + String name, + String entityId, + String type, + Status status, + Instant createdAt, + Instant updatedAt, + Map metadata, + @Nullable + String operationId, + @Nullable + String workerId, + @Nullable + Instant leaseTill +) { + public static OperationTask createPending(String name, String entityId, String type, Map metadata, + String operationId) + { + return new OperationTask(-1, name, entityId, type, 
Status.PENDING, Instant.now(), Instant.now(), + metadata, operationId, null, null); + } + + public enum Status { + PENDING, //just created and waiting to be executed + RUNNING, //acquired by one of the instance of app and executing + FINISHED, //successful finish + FAILED, //unrecoverable failure + STALE, //executed too late and not applicable anymore + } + + public record Update( + @Nullable Status status, + @Nullable Map metadata, + @Nullable String operationId + ) { + public boolean isEmpty() { + return status == null && metadata == null && operationId == null; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Status status; + private Map metadata; + private String operationId; + + public Builder status(Status status) { + this.status = status; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder operationId(String operationId) { + this.operationId = operationId; + return this; + } + + public Update build() { + return new Update(status, metadata, operationId); + } + } + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java new file mode 100644 index 0000000000..bc0208c555 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskResolver.java @@ -0,0 +1,36 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; + +import java.sql.SQLException; + +public interface OperationTaskResolver { + Result resolve(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException; + + enum Status { + SUCCESS, + STALE, + BAD_STATE, + UNKNOWN_TASK, + RESOLVE_ERROR, + } + + record Result( + @Nullable OpTaskAwareAction action, + Status status, + @Nullable Exception exception + ) { + public static final 
Result STALE = new Result(null, Status.STALE, null); + public static final Result BAD_STATE = new Result(null, Status.BAD_STATE, null); + public static final Result UNKNOWN_TASK = new Result(null, Status.UNKNOWN_TASK, null); + + public static Result success(OpTaskAwareAction action) { + return new Result(action, Status.SUCCESS, null); + } + + public static Result resolveError(Exception e) { + return new Result(null, Status.RESOLVE_ERROR, e); + } + } +}
diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java
new file mode 100644
index 0000000000..83124bd6d1
--- /dev/null
+++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/OperationTaskScheduler.java
@@ -0,0 +1,261 @@
+package ai.lzy.longrunning.task;
+
+import ai.lzy.longrunning.OperationsExecutor;
+import ai.lzy.longrunning.task.dao.OperationTaskDao;
+import ai.lzy.model.db.DbHelper;
+import ai.lzy.model.db.Storage;
+import ai.lzy.model.db.TransactionHandle;
+import jakarta.annotation.Nullable;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.sql.SQLException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Objects;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Periodically locks pending {@link OperationTask}s, resolves them into {@link OpTaskAwareAction}s and
+ * starts those on the {@link OperationsExecutor}, tracking a quota of tasks that may run at once.
+ */
+public class OperationTaskScheduler {
+
+    private static final Logger LOG = LogManager.getLogger(OperationTaskScheduler.class);
+
+    private final OperationTaskDao opTaskDao;
+    private final OperationsExecutor operationsExecutor;
+    private final OperationTaskResolver resolver;
+    private final ScheduledExecutorService scheduler;
+    private final Duration initialDelay;
+    private final Duration executionDelay;
+    private final Storage storage;
+    private final TaskMetricsProvider metricsProvider;
+    private final String instanceId;
+    private final Duration leaseDuration;
+    private final int batchSize;
+    private final AtomicInteger runningTaskQuota = new AtomicInteger();
+
+    private final AtomicBoolean started = new AtomicBoolean(false);
+    private volatile boolean disabled = false;
+
+    public OperationTaskScheduler(OperationTaskDao opTaskDao, OperationsExecutor operationsExecutor,
+                                  OperationTaskResolver resolver, Duration initialDelay, Duration executionDelay,
+                                  Storage storage, TaskMetricsProvider metricsProvider, String instanceId,
+                                  Duration leaseDuration, int batchSize, int maxRunningTasks)
+    {
+        this.opTaskDao = opTaskDao;
+        this.operationsExecutor = operationsExecutor;
+        this.resolver = resolver;
+        this.initialDelay = initialDelay;
+        this.executionDelay = executionDelay;
+        this.storage = storage;
+        this.leaseDuration = leaseDuration;
+        this.batchSize = batchSize;
+        //it's important to have only one thread so that all commands are executed sequentially
+        this.scheduler = Executors.newSingleThreadScheduledExecutor();
+        this.metricsProvider = metricsProvider;
+        this.instanceId = instanceId;
+        this.runningTaskQuota.set(maxRunningTasks);
+    }
+
+    /**
+     * Schedules the one-off restore of tasks already owned by this instance, then the periodic main loop.
+     *
+     * @throws IllegalStateException if the scheduler has already been started
+     */
+    public void start() {
+        if (!started.compareAndSet(false, true)) {
+            throw new IllegalStateException("Task executor has already started!");
+        }
+        restoreTasks();
+        startMainLoop();
+    }
+
+    private void acquireTask() {
+        acquireTasks(1);
+    }
+
+    private void acquireTasks(int count) {
+        var newQuota = runningTaskQuota.addAndGet(-count);
+        LOG.debug("Acquired {} tasks to run. Current quota: {}", count, newQuota);
+        metricsProvider.runningTasks().inc(count);
+    }
+
+    /**
+     * Gives back the quota slot taken for a started task.
+     */
+    public void releaseTask(OperationTask task) {
+        LOG.debug("Finishing task {}", task.id());
+        runningTaskQuota.incrementAndGet();
+        metricsProvider.runningTasks().dec();
+    }
+
+    private ScheduledFuture<?> startMainLoop() {
+        return scheduler.scheduleWithFixedDelay(() -> {
+            try {
+                var quota = runningTaskQuota.get();
+                if (quota <= 0) {
+                    LOG.info("Not enough quota to start new operation tasks");
+                    return;
+                }
+                var toLoad = Math.min(batchSize, quota);
+                var actions = new ArrayList<OpTaskAwareAction>();
+                DbHelper.withRetries(LOG, () -> {
+                    //withRetries() may run this block again after a failure: drop actions of the previous attempt
+                    actions.clear();
+                    try (var tx = TransactionHandle.create(storage)) {
+                        for (OperationTask operationTask : opTaskDao.lockPendingBatch(instanceId, leaseDuration,
+                            toLoad, tx))
+                        {
+                            if (disabled) {
+                                //the transaction is left uncommitted, so nothing resolved in it may be started
+                                actions.clear();
+                                return;
+                            }
+                            var taskAwareAction = resolveTask(operationTask, tx);
+                            if (taskAwareAction != null) {
+                                actions.add(taskAwareAction);
+                            }
+                        }
+                        tx.commit();
+                    }
+                });
+                acquireTasks(actions.size());
+                actions.forEach(operationsExecutor::startNew);
+            } catch (Exception e) {
+                LOG.error("Got exception while scheduling task", e);
+                metricsProvider.schedulerErrors().inc();
+            }
+        }, initialDelay.toMillis(), executionDelay.toMillis(), TimeUnit.MILLISECONDS);
+    }
+
+    @Nullable
+    private OpTaskAwareAction resolveTask(OperationTask operationTask, TransactionHandle tx) throws SQLException {
+        var resolveResult = resolver.resolve(operationTask, tx);
+        if (resolveResult.status() != OperationTaskResolver.Status.SUCCESS) {
+            metricsProvider.schedulerResolveErrors(resolveResult.status()).inc();
+        }
+        switch (resolveResult.status()) {
+            case SUCCESS -> {
+                //asserts are usually disabled at runtime, so the check is explicit
+                var action = Objects.requireNonNull(resolveResult.action(),
+                    "Resolver returned SUCCESS without an action");
+                var updatedTask = setStatus(operationTask, OperationTask.Status.RUNNING, tx);
+                //update() is @Nullable: never hand a null task to the action
+                action.setTask(updatedTask != null ? updatedTask : operationTask);
+                return action;
+            }
+            case STALE -> {
+                LOG.warn("Marking task {} as STALE", operationTask.id(), resolveResult.exception());
+                setStatus(operationTask, OperationTask.Status.STALE, tx);
+            }
+            case BAD_STATE -> {
+                LOG.error("Marking task {} as FAILED", operationTask.id(), resolveResult.exception());
+                setStatus(operationTask, OperationTask.Status.FAILED, tx);
+            }
+            case UNKNOWN_TASK, RESOLVE_ERROR -> {
+                //NOTE(review): the task status is not changed here, so it stays as the lock query left it
+                //while no action is started for it - confirm this is intended
+                LOG.warn("Couldn't resolve task {}", operationTask.id(), resolveResult.exception());
+            }
+        }
+        return null;
+    }
+
+    private void restoreTasks() {
+        scheduler.schedule(() -> {
+            try {
+                var actionsToRun = new ArrayList<OpTaskAwareAction>();
+                //allow over-quoting to complete unfinished tasks first
+                DbHelper.withRetries(LOG, () -> {
+                    //withRetries() may run this block again after a failure: drop actions of the previous attempt
+                    actionsToRun.clear();
+                    try (var tx = TransactionHandle.create(storage)) {
+                        var tasks = opTaskDao.recaptureOldTasks(instanceId, leaseDuration, tx);
+                        for (OperationTask operationTask : tasks) {
+                            var taskAwareAction = resolveTask(operationTask, tx);
+                            if (taskAwareAction != null) {
+                                actionsToRun.add(taskAwareAction);
+                            }
+                        }
+                        tx.commit();
+                    }
+                });
+                acquireTasks(actionsToRun.size());
+                actionsToRun.forEach(operationsExecutor::startNew);
+            } catch (Exception e) {
+                LOG.error("Got exception while restoring tasks", e);
+                metricsProvider.schedulerErrors().inc();
+            }
+        }, 0, TimeUnit.MILLISECONDS);
+    }
+
+    private OperationTask setStatus(OperationTask operationTask, OperationTask.Status status, TransactionHandle tx)
+        throws SQLException
+    {
+        return opTaskDao.update(operationTask.id(), OperationTask.Update.builder()
+            .status(status)
+            .build(), tx);
+    }
+
+    public OperationTask saveTask(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException {
+        return opTaskDao.insert(operationTask, tx);
+    }
+
+    /**
+     * Asks the scheduler thread to lock and start the given task right away, without waiting for the main loop.
+     *
+     * @return future that yields {@code true} if the task was started by this call and {@code false} if there
+     *     was no quota, the task could not be locked or resolved, or an error occurred
+     */
+    public ScheduledFuture<Boolean> startImmediately(OperationTask opTask) {
+        //schedule runnable to start task immediately.
+        //if task is already captured by main loop, this runnable will exit
+        return scheduler.schedule(() -> {
+            try {
+                if (!hasQuota()) {
+                    LOG.debug("Not enough quota to start immediate operation task {}", opTask.id());
+                    return false;
+                }
+                var action = DbHelper.withRetries(LOG, () -> {
+                    try (var tx = TransactionHandle.create(storage)) {
+                        var lockedTask = opTaskDao.tryLockTask(opTask.id(),
+                            opTask.entityId(), instanceId, leaseDuration,
+                            tx);
+                        if (lockedTask == null) {
+                            return null;
+                        }
+                        var taskAwareAction = resolveTask(lockedTask, tx);
+                        tx.commit();
+                        return taskAwareAction;
+                    }
+                });
+                if (action == null) {
+                    return false;
+                }
+                acquireTask();
+                operationsExecutor.startNew(action);
+                return true;
+            } catch (Exception e) {
+                LOG.error("Got exception while scheduling task", e);
+                metricsProvider.schedulerErrors().inc();
+                return false;
+            }
+        }, 0, TimeUnit.MILLISECONDS);
+    }
+
+    private boolean hasQuota() {
+        return runningTaskQuota.get() > 0;
+    }
+
+    public void shutdown() {
+        disabled = true;
+        scheduler.shutdown();
+    }
+}
diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java new file mode 100644 index 0000000000..a2424047be --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/ResolverUtils.java @@ -0,0 +1,19 @@ +package ai.lzy.longrunning.task; + +import jakarta.annotation.Nullable; + +import java.util.Map; + +public final class ResolverUtils { + private ResolverUtils() { + } + + @Nullable + public static String readString(Map meta, String key) { + var obj = meta.get(key); + if (obj instanceof String s) { + return s; + } + return null; + } +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java new file mode 100644 index 0000000000..448ebb6b89 --- /dev/null +++ 
b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TaskMetricsProvider.java @@ -0,0 +1,10 @@ +package ai.lzy.longrunning.task; + +import io.prometheus.client.Gauge; + +public interface TaskMetricsProvider { + Gauge schedulerErrors(); + Gauge schedulerResolveErrors(OperationTaskResolver.Status status); + Gauge queueSize(); + Gauge runningTasks(); +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TypedOperationTaskResolver.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TypedOperationTaskResolver.java new file mode 100644 index 0000000000..4a4f22bd97 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/TypedOperationTaskResolver.java @@ -0,0 +1,5 @@ +package ai.lzy.longrunning.task; + +public interface TypedOperationTaskResolver extends OperationTaskResolver { + String type(); +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java new file mode 100644 index 0000000000..56dd5a76b7 --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDao.java @@ -0,0 +1,38 @@ +package ai.lzy.longrunning.task.dao; + +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.model.db.TransactionHandle; +import jakarta.annotation.Nullable; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.List; + +public interface OperationTaskDao { + + @Nullable + OperationTask get(long id, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + OperationTask update(long id, OperationTask.Update update, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + OperationTask updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException; + + @Nullable + OperationTask insert(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException; + + void delete(long id, @Nullable 
TransactionHandle tx) throws SQLException; + + List lockPendingBatch(String ownerId, Duration leaseDuration, int batchSize, + @Nullable TransactionHandle tx) throws SQLException; + + List recaptureOldTasks(String ownerId, Duration leaseDuration, @Nullable TransactionHandle tx) + throws SQLException; + + @Nullable + OperationTask tryLockTask(Long taskId, String entityId, String ownerId, Duration leaseDuration, + @Nullable TransactionHandle tx) throws SQLException; + + List getAll(@Nullable TransactionHandle tx) throws SQLException; +} diff --git a/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java new file mode 100644 index 0000000000..93cb23912b --- /dev/null +++ b/lzy/long-running/src/main/java/ai/lzy/longrunning/task/dao/OperationTaskDaoImpl.java @@ -0,0 +1,319 @@ +package ai.lzy.longrunning.task.dao; + +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.model.db.DbOperation; +import ai.lzy.model.db.Storage; +import ai.lzy.model.db.TransactionHandle; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.annotation.Nullable; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class OperationTaskDaoImpl implements OperationTaskDao { + + private static final String FIELDS = "id, name, entity_id, type, status, created_at, updated_at, metadata," + + " operation_id, worker_id, lease_till"; + public static final String SELECT_QUERY = "SELECT %s FROM operation_task WHERE id = ?".formatted(FIELDS); + public static final String INSERT_QUERY = """ + INSERT INTO operation_task (name, entity_id, type, status, created_at, updated_at, metadata, operation_id) + 
VALUES (?, ?, cast(? as task_type), cast(? as task_status), now(), now(), cast(? as jsonb), ?) + RETURNING %s; + """.formatted(FIELDS); + public static final String GET_ALL_QUERY = """ + SELECT %s FROM operation_task + """.formatted(FIELDS); + + //in first nested request we gather all tasks that either locked or free. + //in second nested request we filter result of previous request to get only free tasks + //and select only specific amount. + private static final String LOCK_PENDING_BATCH_QUERY = """ + UPDATE operation_task + SET status = 'RUNNING', worker_id = ?, updated_at = now(), lease_till = now() + cast(? as interval) + WHERE id IN ( + SELECT id + FROM operation_task + WHERE id IN ( + SELECT DISTINCT ON (entity_id) id + FROM operation_task + WHERE status IN ('PENDING', 'RUNNING') + ORDER BY entity_id, id + ) AND status = 'PENDING' + LIMIT ? + ) + RETURNING %s; + """.formatted(FIELDS); + + private static final String RECAPTURE_OLD_TASKS_QUERY = """ + UPDATE operation_task + SET updated_at = now(), lease_till = now() + cast(? as interval) + WHERE status = 'RUNNING' AND worker_id = ? + RETURNING %s; + """.formatted(FIELDS); + + //in first nested request we select all tasks by entity_id that are either locked or free and take the first one. + //in second nested request we filter result of previous request to get only pending task with specific id. + //if we get the task then we lock it and return it. + private static final String TRY_LOCK_TASK_QUERY = """ + UPDATE operation_task + SET status = 'RUNNING', worker_id = ?, updated_at = now(), lease_till = now() + cast(? as interval) + WHERE id IN ( + SELECT id + FROM operation_task + WHERE id IN ( + SELECT id + FROM operation_task + WHERE status IN ('PENDING', 'RUNNING') AND entity_id = ? + ORDER BY id + LIMIT 1 + ) AND status = 'PENDING' AND id = ? 
+ ) + RETURNING %s; + """.formatted(FIELDS); + + public static final String DELETE_QUERY = "DELETE FROM operation_task WHERE id = ?"; + public static final String UPDATE_LEASE_QUERY = """ + UPDATE operation_task + SET lease_till = now() + cast(? as interval) + WHERE id = ? + RETURNING %s + """.formatted(FIELDS); + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() { }; + + + private final Storage storage; + private final ObjectMapper objectMapper; + + public OperationTaskDaoImpl(Storage storage, ObjectMapper objectMapper) { + this.storage = storage; + this.objectMapper = objectMapper; + } + + @Nullable + @Override + public OperationTask get(long id, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(SELECT_QUERY)) { + ps.setLong(1, id); + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) { + return readTask(rs); + } + return null; + } + } + }); + } + + @Override + @Nullable + public OperationTask update(long id, OperationTask.Update update, @Nullable TransactionHandle tx) + throws SQLException + { + if (update.isEmpty()) { + return null; + } + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(prepareUpdateQuery(update))) { + prepareUpdateParameters(ps, id, update); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + @Override + @Nullable + public OperationTask updateLease(long id, Duration duration, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(UPDATE_LEASE_QUERY)) { + ps.setString(1, duration.toString()); + ps.setLong(2, id); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } + }); + } + + @Nullable + @Override + public 
OperationTask insert(OperationTask operationTask, @Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(INSERT_QUERY)) { + var i = 0; + ps.setString(++i, operationTask.name()); + ps.setString(++i, operationTask.entityId()); + ps.setString(++i, operationTask.type()); + ps.setString(++i, operationTask.status().name()); + ps.setString(++i, objectMapper.writeValueAsString(operationTask.metadata())); + ps.setString(++i, operationTask.operationId()); + var rs = ps.executeQuery(); + if (rs.next()) { + return readTask(rs); + } + return null; + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + public void delete(long id, @Nullable TransactionHandle tx) throws SQLException { + DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(DELETE_QUERY)) { + ps.setLong(1, id); + ps.executeUpdate(); + } + }); + } + + @Override + public List lockPendingBatch(String ownerId, Duration leaseDuration, int batchSize, + @Nullable TransactionHandle tx) throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(LOCK_PENDING_BATCH_QUERY)) { + var i = 0; + ps.setString(++i, ownerId); + ps.setString(++i, leaseDuration.toString()); + ps.setInt(++i, batchSize); + try (var rs = ps.executeQuery()) { + var result = new ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + + @Override + public List recaptureOldTasks(String ownerId, Duration leaseTime, @Nullable TransactionHandle tx) + throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(RECAPTURE_OLD_TASKS_QUERY)) { + var i = 0; + ps.setString(++i, leaseTime.toString()); + ps.setString(++i, ownerId); + try (var rs = ps.executeQuery()) { + var result = new 
ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + + @Nullable + @Override + public OperationTask tryLockTask(Long taskId, String entityId, String ownerId, Duration leaseDuration, + @Nullable TransactionHandle tx) throws SQLException + { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(TRY_LOCK_TASK_QUERY)) { + var i = 0; + ps.setString(++i, ownerId); + ps.setString(++i, leaseDuration.toString()); + ps.setString(++i, entityId); + ps.setLong(++i, taskId); + try (var rs = ps.executeQuery()) { + if (rs.next()) { + return readTask(rs); + } + return null; + } + } + }); + } + + @Override + public List getAll(@Nullable TransactionHandle tx) throws SQLException { + return DbOperation.execute(tx, storage, c -> { + try (PreparedStatement ps = c.prepareStatement(GET_ALL_QUERY)) { + try (var rs = ps.executeQuery()) { + var result = new ArrayList(rs.getFetchSize()); + while (rs.next()) { + result.add(readTask(rs)); + } + return result; + } + } + }); + } + + private static String prepareUpdateQuery(OperationTask.Update update) { + var sb = new StringBuilder("UPDATE operation_task SET "); + if (update.status() != null) { + sb.append("status = cast(? as task_status), "); + } + if (update.metadata() != null) { + sb.append("metadata = cast(? as jsonb), "); + } + if (update.operationId() != null) { + sb.append("operation_id = ?, "); + } + sb.setLength(sb.length() - 2); + sb.append(" WHERE id = ? 
RETURNING ").append(FIELDS); + return sb.toString(); + } + + private void prepareUpdateParameters(PreparedStatement ps, long id, OperationTask.Update update) + throws JsonProcessingException, SQLException + { + var i = 0; + if (update.status() != null) { + ps.setString(++i, update.status().name()); + } + if (update.metadata() != null) { + ps.setString(++i, objectMapper.writeValueAsString(update.metadata())); + } + if (update.operationId() != null) { + ps.setString(++i, update.operationId()); + } + ps.setLong(++i, id); + } + + private OperationTask readTask(ResultSet rs) throws SQLException { + try { + var id = rs.getLong("id"); + var name = rs.getString("name"); + var entityId = rs.getString("entity_id"); + var type = rs.getString("type"); + var status = OperationTask.Status.valueOf(rs.getString("status")); + var createdAt = rs.getTimestamp("created_at").toInstant(); + var updatedAt = rs.getTimestamp("updated_at").toInstant(); + var metadata = objectMapper.readValue(rs.getString("metadata"), + MAP_TYPE_REFERENCE); + var operationId = rs.getString("operation_id"); + var workerId = rs.getString("worker_id"); + var leaseTillTs = rs.getTimestamp("lease_till"); + var leaseTill = leaseTillTs != null ? 
leaseTillTs.toInstant() : null; + return new OperationTask(id, name, entityId, type, status, createdAt, updatedAt, metadata, operationId, + workerId, + leaseTill); + } catch (JsonProcessingException e) { + throw new RuntimeException("Cannot read metadata for operation task object", e); + } + + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java new file mode 100644 index 0000000000..c0a7da1a1d --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/OperationTaskDaoImplTest.java @@ -0,0 +1,159 @@ +package ai.lzy.longrunning; + +import ai.lzy.longrunning.task.OperationTask; +import ai.lzy.longrunning.task.dao.OperationTaskDaoImpl; +import ai.lzy.model.db.StorageImpl; +import ai.lzy.model.db.test.DatabaseTestUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; +import io.zonky.test.db.postgres.junit.PreparedDbRule; +import org.junit.*; + +import java.time.Duration; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class OperationTaskDaoImplTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); + + private StorageImpl storage; + private OperationTaskDaoImpl taskDao; + + + @Before + public void setup() { + storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), + "classpath:db/migrations") {}; + taskDao = new OperationTaskDaoImpl(storage, new ObjectMapper()); + } + + @After + public void teardown() { + DatabaseTestUtils.cleanup(storage); + storage.close(); + } + + @Test + public void create() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var fetched = 
taskDao.get(task.id(), null); + assertEquals(task, fetched); + } + + @Test + public void multiCreate() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var task2 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var task3 = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + Assert.assertTrue(task1.id() < task2.id()); + Assert.assertTrue(task2.id() < task3.id()); + } + + @Test + public void update() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var updated = taskDao.update(task.id(), OperationTask.Update.builder().status(OperationTask.Status.RUNNING).build(), null); + assertEquals(OperationTask.Status.RUNNING, updated.status()); + updated = taskDao.update(task.id(), OperationTask.Update.builder().operationId("42").build(), null); + assertEquals("42", updated.operationId()); + updated = taskDao.update(task.id(), OperationTask.Update.builder().metadata(Map.of("qux", "quux")).build(), null); + assertEquals(Map.of("qux", "quux"), updated.metadata()); + updated = taskDao.update(task.id(), OperationTask.Update.builder() + .status(OperationTask.Status.FINISHED) + .operationId("0") + .metadata(Map.of()) + .build(), null); + assertEquals(OperationTask.Status.FINISHED, updated.status()); + assertEquals("0", updated.operationId()); + assertEquals(Map.of(), updated.metadata()); + } + + @Test + public void delete() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + taskDao.delete(task.id(), null); + } + + @Test + public void updateLease() throws Exception { + var task = taskDao.insert(OperationTask.createPending("foo", "bar", "MOUNT", Map.of("foo", "bar"), null), null); + var updateLease = 
taskDao.updateLease(task.id(), Duration.ofMinutes(5), null); + var between = Duration.between(task.createdAt(), updateLease.leaseTill()); + //leaseTill should be around 5 minutes from now + assertTrue(between.compareTo(Duration.ofMinutes(5)) >= 0); + } + + @Test + public void getUnknown() throws Exception { + var task = taskDao.get(42, null); + Assert.assertNull(task); + } + + @Test + public void updateUnknown() throws Exception { + var updated = taskDao.update(42, OperationTask.Update.builder().status(OperationTask.Status.RUNNING).build(), null); + Assert.assertNull(updated); + } + + @Test + public void deleteUnknown() throws Exception { + taskDao.delete(42, null); + } + + @Test + public void updateLeaseUnknown() throws Exception { + var updated = taskDao.updateLease(42, Duration.ofMinutes(5), null); + Assert.assertNull(updated); + } + + @Test + public void lockPendingBatch() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of(), null), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of(), null), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of(), null), null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of(), null), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of(), null), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of(), null), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of(), null), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of(), null), null); + + task2 = taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(OperationTask.Status.FINISHED), null); + task4 = taskDao.update(task4.id(), statusUpdate(OperationTask.Status.FAILED), 
null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + var lockedTaskIds = lockedTasks.stream().map(OperationTask::id).collect(Collectors.toSet()); + assertEquals(Set.of(task1.id(), task7.id(), task8.id()), lockedTaskIds); + } + + @Test + public void lockPendingBatchWithAllRunning() throws Exception { + var task1 = taskDao.insert(OperationTask.createPending("task1", "1", "MOUNT", Map.of(), null), null); + var task2 = taskDao.insert(OperationTask.createPending("task2", "2", "MOUNT", Map.of(), null), null); + var task3 = taskDao.insert(OperationTask.createPending("task3", "3", "MOUNT", Map.of(), null), null); + var task4 = taskDao.insert(OperationTask.createPending("task4", "4", "MOUNT", Map.of(), null), null); + var task5 = taskDao.insert(OperationTask.createPending("task5", "1", "MOUNT", Map.of(), null), null); + var task6 = taskDao.insert(OperationTask.createPending("task6", "2", "MOUNT", Map.of(), null), null); + var task7 = taskDao.insert(OperationTask.createPending("task7", "3", "MOUNT", Map.of(), null), null); + var task8 = taskDao.insert(OperationTask.createPending("task8", "4", "MOUNT", Map.of(), null), null); + + task1 = taskDao.update(task1.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task2 = taskDao.update(task2.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task3 = taskDao.update(task3.id(), statusUpdate(OperationTask.Status.RUNNING), null); + task4 = taskDao.update(task4.id(), statusUpdate(OperationTask.Status.RUNNING), null); + + var lockedTasks = taskDao.lockPendingBatch("worker", Duration.ofMinutes(5), 10, null); + assertTrue(lockedTasks.isEmpty()); + } + + private OperationTask.Update statusUpdate(OperationTask.Status status) { + return OperationTask.Update.builder().status(status).build(); + } + +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java 
b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java new file mode 100644 index 0000000000..18957f3b64 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/OperationTaskSchedulerTest.java @@ -0,0 +1,259 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.Operation; +import ai.lzy.longrunning.OperationRunnerBase; +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDaoImpl; +import ai.lzy.longrunning.task.dao.OperationTaskDaoImpl; +import ai.lzy.model.db.StorageImpl; +import ai.lzy.model.db.TransactionHandle; +import ai.lzy.model.db.test.DatabaseTestUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; +import io.zonky.test.db.postgres.junit.PreparedDbRule; +import jakarta.annotation.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.locks.LockSupport; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.Assert.*; + +public class OperationTaskSchedulerTest { + + public static final String MOUNT_TASK_TYPE = "MOUNT"; + public static final Duration SCHEDULER_DELAY = Duration.ofMinutes(5); + public static final Duration LEASE_DURATION = Duration.ofMinutes(5); + public static final int BATCH_SIZE = 7; + public static final int MAX_RUNNING_TASKS = 10; + public static final String WORKER_ID = "worker-42"; + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(ds -> {}); + + @Rule + public Timeout timeout = Timeout.seconds(30); + + private StorageImpl storage; + private OperationTaskDaoImpl taskDao; + private 
OperationTaskScheduler taskScheduler; + private OperationDaoImpl operationDao; + private OperationsExecutor operationsExecutor; + private DispatchingOperationTaskResolver taskResolver; + + @Before + public void setup() { + storage = new StorageImpl(DatabaseTestUtils.preparePostgresConfig(db.getConnectionInfo()), + "classpath:db/migrations") {}; + taskDao = new OperationTaskDaoImpl(storage, new ObjectMapper()); + operationDao = new OperationDaoImpl(storage); + operationsExecutor = new OperationsExecutor(5, 10, () -> {}, e -> false); + taskResolver = new DispatchingOperationTaskResolver(List.of()); + taskScheduler = new OperationTaskScheduler(taskDao, operationsExecutor, taskResolver, Duration.ZERO, + SCHEDULER_DELAY, storage, new StubMetricsProvider(), WORKER_ID , LEASE_DURATION, + BATCH_SIZE, MAX_RUNNING_TASKS); + + } + + @After + public void teardown() { + DatabaseTestUtils.cleanup(storage); + storage.close(); + taskScheduler.shutdown(); + } + + @Test + public void schedulerWorkflow() { + taskScheduler.start(); + } + + @Test + public void schedulerCannotBeStartedTwice() { + taskScheduler.start(); + assertThrows(IllegalStateException.class, () -> taskScheduler.start()); + } + + @Test + public void schedulerShouldWork() throws SQLException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + taskScheduler.start(); + op = waitForOperation(op); + assertNull(op.error()); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FINISHED, task.status()); + assertTrue(task.createdAt().isBefore(task.updatedAt())); + assertNotNull(task.leaseTill()); + assertTrue(task.createdAt().isBefore(task.leaseTill())); + assertEquals(WORKER_ID, task.workerId()); + } + + @Test + public void schedulerCanFail() throws SQLException { + 
taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, false)); + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + taskScheduler.start(); + op = waitForOperation(op); + assertNotNull(op.error()); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FAILED, task.status()); + assertTrue(task.createdAt().isBefore(task.updatedAt())); + assertNotNull(task.leaseTill()); + assertTrue(task.createdAt().isBefore(task.leaseTill())); + assertEquals(WORKER_ID, task.workerId()); + } + + @Test + public void schedulerWillLoadOnlyOneBatch() throws SQLException, InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> { + latch.countDown(); + return OperationRunnerBase.StepResult.FINISH; + }, true)); + for (int i = 0; i < MAX_RUNNING_TASKS; i++) { + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo" + i, + MOUNT_TASK_TYPE, Map.of(), op.id()), null); + } + taskScheduler.start(); + latch.await(); + var tasks = taskDao.getAll(null); + var tasksByPendingStatus = tasks.stream() + .collect(Collectors.partitioningBy(x -> x.status() == OperationTask.Status.PENDING)); + var pendingTasks = tasksByPendingStatus.get(true); + var runningTasks = tasksByPendingStatus.get(false); + assertEquals(MAX_RUNNING_TASKS - BATCH_SIZE, pendingTasks.size()); + assertEquals(BATCH_SIZE, runningTasks.size()); + } + + @Test + public void schedulerCanScheduleImmediateTask() throws SQLException, ExecutionException, InterruptedException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + taskScheduler.start(); + + var op = createOperation(); + var task = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, 
Map.of(), op.id()), null); + var scheduledFuture = taskScheduler.startImmediately(task); + var scheduled = scheduledFuture.get(); + assertTrue(scheduled); + + waitForOperation(op); + task = taskDao.get(task.id(), null); + assertNotNull(task); + assertEquals(OperationTask.Status.FINISHED, task.status()); + } + + @Test + public void schedulerCannotScheduleTwoImmediateTaskWithSameEntityId() + throws SQLException, ExecutionException, InterruptedException + { + CountDownLatch latch = new CountDownLatch(1); + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> { + try { + latch.await(); + } catch (InterruptedException e) {} + return OperationRunnerBase.StepResult.FINISH; + }, true)); + taskScheduler.start(); + + var op1 = createOperation(); + var task1 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op1.id()), null); + var op2 = createOperation(); + var task2 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op2.id()), null); + var task1Future = taskScheduler.startImmediately(task1); + var task2Future = taskScheduler.startImmediately(task2); + assertTrue(task1Future.get()); + assertFalse(task2Future.get()); + + task1 = taskDao.get(task1.id(), null); + task2 = taskDao.get(task2.id(), null); + assertNotNull(task1); + assertNotNull(task2); + assertEquals(OperationTask.Status.RUNNING, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + assertEquals(OperationTask.Status.PENDING, task2.status()); + assertNull(task2.workerId()); + latch.countDown(); + } + + @Test + public void schedulerShouldRestartOldTasksOnStart() throws SQLException { + taskResolver.addResolver(resolver(MOUNT_TASK_TYPE, () -> OperationRunnerBase.StepResult.FINISH, true)); + var op1 = createOperation(); + var task1 = taskScheduler.saveTask(OperationTask.createPending("Test", "foo", + MOUNT_TASK_TYPE, Map.of(), op1.id()), null); + task1 = taskDao.tryLockTask(task1.id(), task1.entityId(), WORKER_ID, 
LEASE_DURATION, null); + assertNotNull(task1); + assertEquals(OperationTask.Status.RUNNING, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + + taskScheduler.start(); + waitForOperation(op1); + + task1 = taskDao.get(task1.id(), null); + assertNotNull(task1); + assertEquals(OperationTask.Status.FINISHED, task1.status()); + assertEquals(WORKER_ID, task1.workerId()); + } + + private Operation waitForOperation(Operation op) throws SQLException { + return waitForOperation(op.id()); + } + + private Operation waitForOperation(String opId) throws SQLException { + while (true) { + var op = operationDao.get(opId, null); + if (op == null) { + fail("Operation " + opId + " cannot be null"); + } + if (op.done()) { + return op; + } + LockSupport.parkNanos(Duration.ofSeconds(1).toNanos()); + } + } + + private Operation createOperation() throws SQLException { + var operation = Operation.create("foo", "op", Duration.ofDays(1), null, null); + operationDao.create(operation, null); + return operation; + } + + private TypedOperationTaskResolver resolver(String type, + Supplier action, + boolean completeOperation) { + return new TypedOperationTaskResolver() { + @Override + public String type() { + return type; + } + + @Override + public Result resolve(OperationTask task, @Nullable TransactionHandle tx) { + return Result.success(new TestAction(action, completeOperation, task, taskDao, + SCHEDULER_DELAY, task.operationId(), "Test action", storage, operationDao, + operationsExecutor, taskScheduler)); + } + }; + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java new file mode 100644 index 0000000000..b1a259d722 --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/StubMetricsProvider.java @@ -0,0 +1,38 @@ +package ai.lzy.longrunning.task; + +import io.prometheus.client.Gauge; + +public class StubMetricsProvider implements 
TaskMetricsProvider { + + @Override + public Gauge schedulerErrors() { + return Gauge.build() + .name("scheduler_errors") + .help("help") + .create(); + } + + @Override + public Gauge schedulerResolveErrors(OperationTaskResolver.Status status) { + return Gauge.build() + .name("scheduler_resolve_errors") + .help("help") + .create(); + } + + @Override + public Gauge queueSize() { + return Gauge.build() + .name("queue_size") + .help("help") + .create(); + } + + @Override + public Gauge runningTasks() { + return Gauge.build() + .name("running_tasks") + .help("help") + .create(); + } +} diff --git a/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java new file mode 100644 index 0000000000..38377744ae --- /dev/null +++ b/lzy/long-running/src/test/java/ai/lzy/longrunning/task/TestAction.java @@ -0,0 +1,51 @@ +package ai.lzy.longrunning.task; + +import ai.lzy.longrunning.OperationsExecutor; +import ai.lzy.longrunning.dao.OperationDao; +import ai.lzy.longrunning.task.dao.OperationTaskDao; +import ai.lzy.model.db.Storage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import io.grpc.Status; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.List; +import java.util.function.Supplier; + +public class TestAction extends OpTaskAwareAction { + + private final Supplier action; + private final boolean completeOperation; + + public TestAction(Supplier action, boolean completeOperation, OperationTask operationTask, + OperationTaskDao operationTaskDao, Duration leaseDuration, String opId, String desc, + Storage storage, OperationDao operationsDao, + OperationsExecutor executor, + OperationTaskScheduler operationTaskScheduler) + { + super(operationTask, operationTaskDao, leaseDuration, opId, desc, storage, operationsDao, executor, + operationTaskScheduler); + this.action = action; + this.completeOperation = completeOperation; + } + + @Override + 
protected List> steps() { + return List.of(this::doSmth); + } + + private StepResult doSmth() { + var stepResult = action.get(); + try { + if (completeOperation) { + completeOperation(null, Any.pack(Empty.getDefaultInstance()), null); + } else { + failOperation(Status.INTERNAL, null); + } + } catch (SQLException e) { + log().error("{} Error while completing operation", logPrefix()); + } + return stepResult; + } +} diff --git a/lzy/long-running/src/test/resources/db/migrations/V1__task.sql b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql new file mode 100644 index 0000000000..925cbf506e --- /dev/null +++ b/lzy/long-running/src/test/resources/db/migrations/V1__task.sql @@ -0,0 +1,44 @@ +CREATE TABLE operation +( + id TEXT NOT NULL PRIMARY KEY, + created_by TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + modified_at TIMESTAMP NOT NULL, + description TEXT NOT NULL, + deadline TIMESTAMP NULL, + done BOOLEAN NOT NULL, + + meta BYTEA NULL, + response BYTEA NULL, + error BYTEA NULL, + + idempotency_key TEXT NULL, + request_hash TEXT NULL, + + CHECK (((idempotency_key IS NOT NULL) AND (request_hash IS NOT NULL)) OR + ((idempotency_key IS NULL) AND (request_hash IS NULL))) +); + +CREATE UNIQUE INDEX idempotency_key_to_operation_index ON operation (idempotency_key); +CREATE UNIQUE INDEX failed_operations_index ON operation (id) WHERE done = TRUE AND error IS NOT NULL; +CREATE UNIQUE INDEX completed_operations_index ON operation (id) WHERE done = TRUE; + +CREATE TYPE task_status AS ENUM ('PENDING', 'RUNNING', 'FAILED', 'FINISHED'); + +CREATE TYPE task_type AS ENUM ('UNMOUNT', 'MOUNT'); + +CREATE TABLE IF NOT EXISTS operation_task( + id BIGSERIAL NOT NULL, + name TEXT NOT NULL, + entity_id TEXT NOT NULL, + type task_type NOT NULL, + status task_status NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + metadata JSONB NOT NULL, + operation_id TEXT, + worker_id TEXT, + lease_till TIMESTAMP, + PRIMARY KEY (id), + FOREIGN KEY 
(operation_id) REFERENCES operation(id) +); \ No newline at end of file diff --git a/lzy/long-running/src/test/resources/log4j2.yaml b/lzy/long-running/src/test/resources/log4j2.yaml new file mode 100644 index 0000000000..c1554d579e --- /dev/null +++ b/lzy/long-running/src/test/resources/log4j2.yaml @@ -0,0 +1,28 @@ +Configuration: + status: warn + + Appenders: + Console: + name: Console + target: SYSTEM_OUT + PatternLayout: + Pattern: "%d{yyyy-MM-dd HH:mm:ss.SSS}{UTC} [%t] %-5level %logger{36} %notEmpty{[rid=%X{rid}] }- %msg%n" + + Loggers: + Root: + level: warn + AppenderRef: + ref: Console + + Logger: + - name: UserEventLogs + level: error + + - name: "ai.lzy.model.utils.FreePortFinder" + level: error + + - name: io.zonky.test + level: warn + + - name: org.flywaydb + level: warn diff --git a/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java b/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java index 8cb35a7372..6bd24d3d62 100644 --- a/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java +++ b/util/util-db/src/main/java/ai/lzy/model/db/test/DatabaseTestUtils.java @@ -1,5 +1,6 @@ package ai.lzy.model.db.test; +import ai.lzy.model.db.DatabaseConfiguration; import ai.lzy.model.db.Storage; import java.sql.SQLException; @@ -29,6 +30,28 @@ public static HashMap preparePostgresConfig(String app, Object c } } + public static DatabaseConfiguration preparePostgresConfig(Object ci) { + /* io.zonky.test.db.postgres.embedded.ConnectionInfo ci */ + try { + var user = getFieldValue(ci, "user", String.class); + assert "postgres".equals(user); + + var port = getFieldValue(ci, "port", Integer.class); + var dbName = getFieldValue(ci, "dbName", String.class); + + var result = new DatabaseConfiguration(); + result.setEnabled(true); + result.setUrl("jdbc:postgresql://localhost:%d/%s".formatted(port, dbName)); + result.setUsername("postgres"); + result.setPassword(""); + result.setMinPoolSize(1); + result.setMaxPoolSize(10); + 
return result; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("checkstyle:Indentation") public static HashMap prepareLocalhostConfig(String app) { return new HashMap<>() {{