Skip to content

Commit 57134f5

Browse files
authored
Fix actor slow startup times (#621)
## Overview Somewhere along the way a regression on actor performance occurred (probably during my large refactor). Cold-start times are extremely large as a result of reporting 0 capacity for the initial heartbeat. This means we need to wait for the next heartbeat for an updated capacity report to assign any tasks to the worker. The problem is that the executor startup was in a race condition with the heartbeat mechanism, so if the Rust worker sent a heartbeat before the executors were initialized then the capacity was 0. This PR adds a future over a boolean to determine whether the runtime has initialized before starting the heartbeat runtime. ## Test Plan Ran this locally and decreased cold-start for "hello world" from 16s to 2s. Going to run some tests on cloud to make sure. ## Rollout Plan (if applicable) This may be rolled out immediately. We will need to cut a new `union` package. ## Upstream Changes Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [ ] To be upstreamed to OSS ## Issue fixes https://linear.app/unionai/issue/COR-2673/fix-startup-way-slower-than-regular-tasks ## Checklist * [ ] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation
1 parent 01553a6 commit 57134f5

File tree

7 files changed

+105
-66
lines changed

7 files changed

+105
-66
lines changed

fasttask/plugin/plugin.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ const maxErrorMessageLength = 102400 // 100kb
4646
var (
4747
statusUpdateNotFoundError = errors.New("StatusUpdateNotFound")
4848
taskContextNotFoundError = errors.New("TaskContextNotFound")
49-
podContainerNotFound = errors.New("PodContainerNotFound")
49+
podContainerNotFoundError = errors.New("PodContainerNotFound")
5050

5151
taskStartTimeTemplateVar = tasklog.MustCreateRegex("taskStartTime")
5252
taskStartTimeUnixMsTemplateVar = tasklog.MustCreateRegex("taskStartTimeUnixMs")
@@ -394,7 +394,7 @@ func (p *Plugin) trySubmitTask(ctx context.Context, tCtx core.TaskExecutionConte
394394
pluginState.LastUpdated = now
395395

396396
taskInfo, err := p.getTaskInfo(ctx, tCtx, initialState.SubmittedAt, time.Now(), executionEnv, queueID, workerID)
397-
if err != nil {
397+
if err != nil && !errors.Is(err, podContainerNotFoundError) {
398398
return nil, core.PhaseInfoUndefined, err
399399
}
400400
phaseInfo = core.PhaseInfoQueuedWithTaskInfo(now, pluginState.PhaseVersion, fmt.Sprintf("task offered to worker %s", workerID), taskInfo)
@@ -557,7 +557,7 @@ func (p *Plugin) getTaskInfo(ctx context.Context, tCtx core.TaskExecutionContext
557557
// an in-memory store the may occur during restarts.
558558
// `pod == nil` may occur if it has not yet been populated in the kubeclient cache or was deleted
559559
logger.Warnf(ctx, "Worker %q not found (exists=%t) in status map for queue %q", workerID, ok, queueID)
560-
return &taskInfo, podContainerNotFound
560+
return &taskInfo, podContainerNotFoundError
561561
}
562562

563563
containerIndex := -1
@@ -569,7 +569,7 @@ func (p *Plugin) getTaskInfo(ctx context.Context, tCtx core.TaskExecutionContext
569569
}
570570
if containerIndex == -1 {
571571
logger.Warnf(ctx, "Container %q not found in pod %q", pod.GetName(), pod.GetName())
572-
return &taskInfo, podContainerNotFound
572+
return &taskInfo, podContainerNotFoundError
573573
}
574574

575575
taskInfo.LogContext = &idlcore.LogContext{
@@ -594,7 +594,7 @@ func (p *Plugin) getTaskInfo(ctx context.Context, tCtx core.TaskExecutionContext
594594

595595
if len(pod.Status.ContainerStatuses) <= containerIndex || pod.Status.ContainerStatuses[containerIndex].ContainerID == "" {
596596
// no container id yet
597-
return &taskInfo, podContainerNotFound
597+
return &taskInfo, podContainerNotFoundError
598598
}
599599

600600
taskTemplate, err := tCtx.TaskReader().Read(ctx)

fasttask/plugin/plugin_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ func TestHandleRunning(t *testing.T) {
665665
expectedPhase: core.PhaseUndefined,
666666
expectedPhaseVersion: 0,
667667
expectedReason: "",
668-
expectedError: podContainerNotFound,
668+
expectedError: podContainerNotFoundError,
669669
expectedLastUpdatedInc: false,
670670
expectedLogs: false,
671671
},
@@ -1002,7 +1002,7 @@ func TestGetTaskInfo(t *testing.T) {
10021002

10031003
taskInfo, err := plugin.getTaskInfo(ctx, tCtx, start, now, executionEnv, queueID, workerID)
10041004

1005-
assert.Equal(t, podContainerNotFound, err)
1005+
assert.Equal(t, podContainerNotFoundError, err)
10061006
assert.Empty(t, taskInfo.Logs)
10071007
assert.Nil(t, taskInfo.LogContext)
10081008
})
@@ -1050,7 +1050,7 @@ func TestGetTaskInfo(t *testing.T) {
10501050

10511051
taskInfo, err := plugin.getTaskInfo(ctx, tCtx, start, now, executionEnv, queueID, workerID)
10521052

1053-
assert.Equal(t, podContainerNotFound, err)
1053+
assert.Equal(t, podContainerNotFoundError, err)
10541054
assert.Empty(t, taskInfo.Logs)
10551055
assert.Nil(t, taskInfo.LogContext)
10561056
})
@@ -1098,7 +1098,7 @@ func TestGetTaskInfo(t *testing.T) {
10981098

10991099
taskInfo, err := plugin.getTaskInfo(ctx, tCtx, start, now, executionEnv, queueID, workerID)
11001100

1101-
assert.Equal(t, podContainerNotFound, err)
1101+
assert.Equal(t, podContainerNotFoundError, err)
11021102
assert.Empty(t, taskInfo.Logs)
11031103
assert.Equal(t, expectedLogCtx, taskInfo.LogContext)
11041104
})

fasttask/plugin/service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ func (f *fastTaskServiceImpl) OfferOnQueue(ctx context.Context, queueID, taskID,
384384
}
385385

386386
// create task status channel
387-
f.taskStatusChannels.Store(taskID, make(chan *workerTaskStatus, GetConfig().TaskStatusBufferSize))
387+
f.taskStatusChannels.LoadOrStore(taskID, make(chan *workerTaskStatus, GetConfig().TaskStatusBufferSize))
388388
return worker.workerID, nil
389389
}
390390

fasttask/worker/bridge/src/bridge.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use std::sync::{Arc, Mutex};
12
use std::time::Duration;
23

4+
use crate::common::{AsyncBool, AsyncBoolFuture};
35
use crate::connection::{ConnectionBuilder, ConnectionRuntime};
46
use crate::heartbeater::{HeartbeatRuntime, Heartbeater};
57
use crate::manager::{CapacityReporter, TaskManager, TaskManagerRuntime};
@@ -19,16 +21,25 @@ pub async fn run<T: ConnectionBuilder, U: Heartbeater + Send, V: TaskManager>(
1921
let (task_status_tx, task_status_rx) = async_channel::unbounded();
2022

2123
// initialize and start manager
24+
let manager_runtime_ready = Arc::new(Mutex::new(AsyncBool::new()));
25+
let manager_runtime_ready_clone = manager_runtime_ready.clone();
26+
2227
let manager_runtime = manager.get_runtime()?; // TODO @hamersaw - handle error
2328
let _manager_handle = tokio::spawn(async move {
2429
// currently panicking if manager runtime fails rather than attempting to restart. this will
2530
// effectively force a new replica and failover tasks. a manager runtime failure should
2631
// only occur as a bug.
27-
if let Err(e) = manager_runtime.run(task_status_tx).await {
32+
if let Err(e) = manager_runtime
33+
.run(manager_runtime_ready_clone, task_status_tx)
34+
.await
35+
{
2836
panic!("manager failed: {}", e);
2937
}
3038
});
3139

40+
let manager_runtime_future = AsyncBoolFuture::new(manager_runtime_ready);
41+
manager_runtime_future.await;
42+
3243
// start heartbeater
3344
let heartbeat_runtime = heartbeater.get_runtime()?; // TODO @hamersaw - handle error
3445
let _heartbeat_handle = tokio::spawn(async move {

fasttask/worker/bridge/src/common.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
use std::collections::HashMap;
2+
use std::future::Future;
3+
use std::pin::Pin;
4+
use std::sync::{Arc, Mutex};
5+
use std::task::{Context, Poll, Waker};
26

37
use async_channel::Sender;
48
use serde::{Deserialize, Serialize};
@@ -35,3 +39,51 @@ pub struct Response {
3539

3640
pub executor_corrupt: bool,
3741
}
42+
43+
/// Future that resolves once the paired [`AsyncBool`] is triggered.
///
/// Polling consumes the trigger: after this future resolves, a fresh poll
/// (of a new or re-created future over the same `AsyncBool`) is `Pending`
/// until `trigger` is called again. This provides one-shot signalling from
/// a producer task to a single awaiting consumer.
pub struct AsyncBoolFuture {
    async_bool: Arc<Mutex<AsyncBool>>,
}

impl AsyncBoolFuture {
    /// Creates a future that completes when `async_bool` is triggered.
    pub fn new(async_bool: Arc<Mutex<AsyncBool>>) -> Self {
        Self { async_bool }
    }
}

impl Future for AsyncBoolFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<()> {
        let mut async_bool = self.async_bool.lock().unwrap();
        if async_bool.value {
            // Consume the trigger so the next await waits for a fresh one.
            async_bool.value = false;
            return Poll::Ready(());
        }

        // Always register the latest waker: the task may be polled from a
        // different context than the previous poll.
        async_bool.waker = Some(ctx.waker().clone());
        Poll::Pending
    }
}

/// A mutex-guarded one-shot boolean flag that wakes an awaiting
/// [`AsyncBoolFuture`]. Intended to be shared as `Arc<Mutex<AsyncBool>>`.
pub struct AsyncBool {
    // True when triggered and not yet observed by a poll.
    value: bool,
    // Waker of the most recent poller, if any.
    waker: Option<Waker>,
}

impl Default for AsyncBool {
    fn default() -> Self {
        Self::new()
    }
}

impl AsyncBool {
    /// Creates an untriggered flag with no registered waker.
    pub fn new() -> Self {
        Self {
            value: false,
            waker: None,
        }
    }

    /// Sets the flag and wakes the registered poller, if any.
    ///
    /// Takes the stored waker out rather than cloning it in place, so a
    /// stale waker is never woken a second time by a later trigger; the
    /// next poll re-registers the current waker.
    pub fn trigger(&mut self) {
        self.value = true;
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}

fasttask/worker/bridge/src/heartbeater.rs

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
use std::future::Future;
2-
use std::pin::Pin;
31
use std::sync::{Arc, Mutex, RwLock};
4-
use std::task::{Context, Poll, Waker};
52
use std::time::Duration;
63

74
use anyhow::Result;
85
use async_channel::{Receiver, Sender};
96
use tokio::time::Interval;
107

11-
use crate::common::{FAILED, SUCCEEDED};
8+
use crate::common::{AsyncBool, AsyncBoolFuture, FAILED, SUCCEEDED};
129
use crate::manager::CapacityReporter;
1310
use crate::pb::fasttask::{HeartbeatRequest, TaskStatus};
1411

@@ -114,59 +111,14 @@ impl HeartbeatRuntime for PeriodicHeartbeatRuntime {
114111
}
115112
}
116113

117-
struct AsyncBoolFuture {
118-
async_bool: Arc<Mutex<AsyncBool>>,
119-
}
120-
121-
impl Future for AsyncBoolFuture {
122-
type Output = ();
123-
124-
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<()> {
125-
let mut async_bool = self.async_bool.lock().unwrap();
126-
if async_bool.value {
127-
async_bool.value = false;
128-
return Poll::Ready(());
129-
}
130-
131-
let waker = ctx.waker().clone();
132-
async_bool.waker = Some(waker);
133-
134-
Poll::Pending
135-
}
136-
}
137-
138-
struct AsyncBool {
139-
value: bool,
140-
waker: Option<Waker>,
141-
}
142-
143-
impl AsyncBool {
144-
fn new() -> Self {
145-
Self {
146-
value: false,
147-
waker: None,
148-
}
149-
}
150-
151-
fn trigger(&mut self) {
152-
self.value = true;
153-
if let Some(waker) = &self.waker {
154-
waker.clone().wake();
155-
}
156-
}
157-
}
158-
159114
struct HeartbeatTrigger {
160115
interval: Interval,
161116
async_bool: Arc<Mutex<AsyncBool>>,
162117
}
163118

164119
impl HeartbeatTrigger {
165120
async fn trigger(&mut self) -> () {
166-
let async_bool_future = AsyncBoolFuture {
167-
async_bool: self.async_bool.clone(),
168-
};
169-
121+
let async_bool_future = AsyncBoolFuture::new(self.async_bool.clone());
170122
tokio::select! {
171123
_ = self.interval.tick() => {},
172124
_ = async_bool_future => {},

fasttask/worker/bridge/src/manager.rs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tokio::process::Command;
1111
use tokio_util::codec::{Framed, LengthDelimitedCodec};
1212
use tracing::warn;
1313

14-
use crate::common::{Executor, TaskContext, SUCCEEDED};
14+
use crate::common::{AsyncBool, Executor, TaskContext, SUCCEEDED};
1515
use crate::pb::fasttask::{Capacity, TaskStatus};
1616
use crate::task::{self};
1717

@@ -32,7 +32,11 @@ pub trait TaskManager {
3232

3333
#[trait_variant::make(TaskManagerRuntime: Send)]
3434
pub trait LocalTaskManagerRuntime {
35-
async fn run(&self, task_status_tx: Sender<TaskStatus>) -> Result<()>;
35+
async fn run(
36+
&self,
37+
ready: Arc<Mutex<AsyncBool>>,
38+
task_status_tx: Sender<TaskStatus>,
39+
) -> Result<()>;
3640
}
3741

3842
pub struct CapacityReporter {
@@ -316,7 +320,11 @@ pub struct MultiProcessRuntime {
316320
}
317321

318322
impl TaskManagerRuntime for MultiProcessRuntime {
319-
async fn run(&self, task_status_tx: Sender<TaskStatus>) -> Result<()> {
323+
async fn run(
324+
&self,
325+
ready: Arc<Mutex<AsyncBool>>,
326+
task_status_tx: Sender<TaskStatus>,
327+
) -> Result<()> {
320328
let (backlog_tx, backlog_rx) = (self.backlog_tx.clone(), self.backlog_rx.clone());
321329
let (executor_tx, executor_rx) = (self.executor_tx.clone(), self.executor_rx.clone());
322330
let (task_assignment_rx, task_contexts) =
@@ -349,6 +357,12 @@ impl TaskManagerRuntime for MultiProcessRuntime {
349357
self.executor_tx.send(executor).await?;
350358

351359
index += 1;
360+
361+
// trigger ready if all executors are initialized
362+
if index == self.parallelism {
363+
let mut ready = ready.lock().unwrap();
364+
ready.trigger();
365+
}
352366
},
353367
task_assignment_result = task_assignment_rx.recv() => {
354368
let task_assignment= task_assignment_result?;
@@ -507,7 +521,16 @@ pub struct SuccessRuntime {
507521
}
508522

509523
impl TaskManagerRuntime for SuccessRuntime {
510-
async fn run(&self, task_status_tx: Sender<TaskStatus>) -> Result<()> {
524+
async fn run(
525+
&self,
526+
ready: Arc<Mutex<AsyncBool>>,
527+
task_status_tx: Sender<TaskStatus>,
528+
) -> Result<()> {
529+
{
530+
let mut ready = ready.lock().unwrap();
531+
ready.trigger();
532+
}
533+
511534
let task_rx = self.task_rx.clone();
512535
loop {
513536
let task_result = task_rx.recv().await;
@@ -563,8 +586,9 @@ mod tests {
563586
assert!(manager_runtime_result.is_ok());
564587
let manager_runtime = manager_runtime_result.unwrap();
565588

589+
let ready = Arc::new(Mutex::new(AsyncBool::new()));
566590
let manager_handle = tokio::spawn(async move {
567-
super::TaskManagerRuntime::run(&manager_runtime, task_status_tx).await
591+
super::TaskManagerRuntime::run(&manager_runtime, ready, task_status_tx).await
568592
});
569593

570594
// validate get capacity works

0 commit comments

Comments
 (0)