Skip to content

Commit 8e516c8

Browse files
Merge branch 'TraceMachina:main' into main
2 parents fa2cdaa + 8c3bacb commit 8e516c8

File tree

15 files changed

+149
-14
lines changed

15 files changed

+149
-14
lines changed

nativelink-config/examples/worker_with_redis_scheduler.json5

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
worker_api_endpoint: {
7979
uri: "grpc://127.0.0.1:50061",
8080
},
81+
max_inflight_tasks: 5,
8182
cas_fast_slow_store: "WORKER_FAST_SLOW_STORE",
8283
upload_action_result: {
8384
ac_store: "AC_MAIN_STORE",

nativelink-config/src/cas_server.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,12 @@ pub struct LocalWorkerConfig {
727727
#[serde(default, deserialize_with = "convert_duration_with_shellexpand")]
728728
pub max_action_timeout: usize,
729729

730+
/// Maximum number of inflight tasks this worker can cope with.
731+
///
732+
/// Default: 0 (unlimited)
733+
#[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
734+
pub max_inflight_tasks: u64,
735+
730736
/// If timeout is handled in `entrypoint` or another wrapper script.
731737
/// If set to true `NativeLink` will not honor the timeout the action requested
732738
/// and instead will always force kill the action after `max_action_timeout`

nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ message ConnectWorkerRequest {
7171
/// append this prefix to the assigned worker_id followed by a UUIDv6.
7272
string worker_id_prefix = 2;
7373

74-
reserved 3; // NextId.
74+
/// Maximum number of inflight tasks this worker can cope with at one time.
75+
/// The default (0) means unlimited.
76+
uint64 max_inflight_tasks = 3;
77+
78+
reserved 4; // NextId.
7579
}
7680

7781
/// The result of an ExecutionRequest.

nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ pub struct ConnectWorkerRequest {
4242
/// / append this prefix to the assigned worker_id followed by a UUIDv6.
4343
#[prost(string, tag = "2")]
4444
pub worker_id_prefix: ::prost::alloc::string::String,
45+
/// / Maximum number of inflight tasks this worker can cope with at one time
46+
/// / The default (0) means unlimited.
47+
#[prost(uint64, tag = "3")]
48+
pub max_inflight_tasks: u64,
4549
}
4650
/// / The result of an ExecutionRequest.
4751
#[derive(Clone, PartialEq, ::prost::Message)]

nativelink-scheduler/src/api_worker_scheduler.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,11 @@ impl ApiWorkerSchedulerImpl {
259259
if !w.can_accept_work() {
260260
if full_worker_logging {
261261
info!(
262-
"Worker {worker_id} cannot accept work: is_paused={}, is_draining={}",
263-
w.is_paused, w.is_draining
262+
"Worker {worker_id} cannot accept work: is_paused={}, is_draining={}, inflight={}/{}",
263+
w.is_paused,
264+
w.is_draining,
265+
w.running_action_infos.len(),
266+
w.max_inflight_tasks
264267
);
265268
}
266269
return false;

nativelink-scheduler/src/worker.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use core::hash::{Hash, Hasher};
16+
use core::u64;
1617
use std::collections::{HashMap, HashSet};
1718
use std::sync::Arc;
1819
use std::time::{SystemTime, UNIX_EPOCH};
@@ -95,6 +96,10 @@ pub struct Worker {
9596
#[metric(help = "If the worker is draining.")]
9697
pub is_draining: bool,
9798

99+
/// Maximum inflight tasks for this worker (or 0 for unlimited)
100+
#[metric(help = "Maximum inflight tasks for this worker (or 0 for unlimited)")]
101+
pub max_inflight_tasks: u64,
102+
98103
/// Stats about the worker.
99104
#[metric]
100105
metrics: Arc<Metrics>,
@@ -134,6 +139,7 @@ impl Worker {
134139
platform_properties: PlatformProperties,
135140
tx: UnboundedSender<UpdateForWorker>,
136141
timestamp: WorkerTimestamp,
142+
max_inflight_tasks: u64,
137143
) -> Self {
138144
Self {
139145
id,
@@ -144,6 +150,7 @@ impl Worker {
144150
last_update_timestamp: timestamp,
145151
is_paused: false,
146152
is_draining: false,
153+
max_inflight_tasks,
147154
metrics: Arc::new(Metrics {
148155
connected_timestamp: SystemTime::now()
149156
.duration_since(UNIX_EPOCH)
@@ -270,8 +277,12 @@ impl Worker {
270277
}
271278
}
272279

273-
pub const fn can_accept_work(&self) -> bool {
274-
!self.is_paused && !self.is_draining
280+
pub fn can_accept_work(&self) -> bool {
281+
!self.is_paused
282+
&& !self.is_draining
283+
&& (self.max_inflight_tasks == 0
284+
|| u64::try_from(self.running_action_infos.len()).unwrap_or(u64::MAX)
285+
< self.max_inflight_tasks)
275286
}
276287
}
277288

nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ async fn setup_new_worker(
369369
props: PlatformProperties,
370370
) -> Result<mpsc::UnboundedReceiver<UpdateForWorker>, Error> {
371371
let (tx, mut rx) = mpsc::unbounded_channel();
372-
let worker = Worker::new(worker_id.clone(), props, tx, NOW_TIME);
372+
let worker = Worker::new(worker_id.clone(), props, tx, NOW_TIME, 0);
373373
scheduler
374374
.add_worker(worker)
375375
.await

nativelink-scheduler/tests/simple_scheduler_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async fn setup_new_worker(
9292
props: PlatformProperties,
9393
) -> Result<mpsc::UnboundedReceiver<UpdateForWorker>, Error> {
9494
let (tx, mut rx) = mpsc::unbounded_channel();
95-
let worker = Worker::new(worker_id.clone(), props, tx, NOW_TIME);
95+
let worker = Worker::new(worker_id.clone(), props, tx, NOW_TIME, 0);
9696
scheduler
9797
.add_worker(worker)
9898
.await

nativelink-service/src/worker_api_server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ impl WorkerApiServer {
189189
platform_properties,
190190
tx,
191191
(self.now_fn)()?.as_secs(),
192+
connect_worker_request.max_inflight_tasks,
192193
);
193194
self.scheduler
194195
.add_worker(worker)

nativelink-service/tests/worker_api_server_test.rs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use nativelink_util::action_messages::{
4545
use nativelink_util::common::DigestInfo;
4646
use nativelink_util::digest_hasher::DigestHasherFunc;
4747
use nativelink_util::operation_state_manager::{UpdateOperationType, WorkerStateManager};
48+
use nativelink_util::platform_properties::PlatformProperties;
4849
use pretty_assertions::assert_eq;
4950
use tokio::join;
5051
use tokio::sync::{Notify, mpsc};
@@ -143,6 +144,14 @@ const fn static_now_fn() -> Result<Duration, Error> {
143144
}
144145

145146
async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result<TestContext, Error> {
147+
setup_api_server_with_task_limit(worker_timeout, now_fn, 0).await
148+
}
149+
150+
async fn setup_api_server_with_task_limit(
151+
worker_timeout: u64,
152+
now_fn: NowFn,
153+
max_worker_tasks: u64,
154+
) -> Result<TestContext, Error> {
146155
const SCHEDULER_NAME: &str = "DUMMY_SCHEDULE_NAME";
147156

148157
const UUID_SIZE: usize = 36;
@@ -172,7 +181,10 @@ async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result<TestCont
172181
)
173182
.err_tip(|| "Error creating WorkerApiServer")?;
174183

175-
let connect_worker_request = ConnectWorkerRequest::default();
184+
let connect_worker_request = ConnectWorkerRequest {
185+
max_inflight_tasks: max_worker_tasks,
186+
..Default::default()
187+
};
176188
let (tx, rx) = mpsc::channel(1);
177189
tx.send(Update::ConnectWorkerRequest(connect_worker_request))
178190
.await
@@ -545,3 +557,75 @@ pub async fn execution_response_success_test() -> Result<(), Box<dyn core::error
545557
}
546558
Ok(())
547559
}
560+
561+
#[nativelink_test]
562+
pub async fn workers_only_allow_max_tasks() -> Result<(), Box<dyn core::error::Error>> {
563+
let test_context =
564+
setup_api_server_with_task_limit(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn), 1).await?;
565+
566+
let selected_worker = test_context
567+
.scheduler
568+
.find_worker_for_action(&PlatformProperties::new(HashMap::new()), true)
569+
.await;
570+
assert_eq!(
571+
selected_worker,
572+
Some(test_context.worker_id.clone()),
573+
"Expected worker to permit tasks to begin with"
574+
);
575+
576+
let action_digest = DigestInfo::new([7u8; 32], 123);
577+
let instance_name = "instance_name".to_string();
578+
579+
let unique_qualifier = ActionUniqueQualifier::Uncacheable(ActionUniqueKey {
580+
instance_name: instance_name.clone(),
581+
digest_function: DigestHasherFunc::Sha256,
582+
digest: action_digest,
583+
});
584+
585+
let action_info = Arc::new(ActionInfo {
586+
command_digest: DigestInfo::new([0u8; 32], 0),
587+
input_root_digest: DigestInfo::new([0u8; 32], 0),
588+
timeout: Duration::MAX,
589+
platform_properties: HashMap::new(),
590+
priority: 0,
591+
load_timestamp: make_system_time(0),
592+
insert_timestamp: make_system_time(0),
593+
unique_qualifier,
594+
});
595+
596+
let platform_properties = test_context
597+
.scheduler
598+
.get_platform_property_manager()
599+
.make_platform_properties(action_info.platform_properties.clone())
600+
.err_tip(|| "Failed to make platform properties in SimpleScheduler::do_try_match")?;
601+
602+
let expected_operation_id = OperationId::default();
603+
604+
test_context
605+
.scheduler
606+
.worker_notify_run_action(
607+
test_context.worker_id.clone(),
608+
expected_operation_id,
609+
ActionInfoWithProps {
610+
inner: action_info,
611+
platform_properties,
612+
},
613+
)
614+
.await
615+
.unwrap();
616+
617+
let selected_worker = test_context
618+
.scheduler
619+
.find_worker_for_action(&PlatformProperties::new(HashMap::new()), true)
620+
.await;
621+
assert_eq!(
622+
selected_worker, None,
623+
"Expected not to be able to give worker a second task"
624+
);
625+
626+
assert!(logs_contain(
627+
"cannot accept work: is_paused=false, is_draining=false, inflight=1/1"
628+
));
629+
630+
Ok(())
631+
}

0 commit comments

Comments
 (0)