Skip to content

Commit 2d82f56

Browse files
committed
Wire LM planner into generation-daemon
Add --use-lm flag to load the LM planner on startup and run it before each generation request. The LM output rewrites the caption and fills in BPM, key/scale, time signature, and language; a user-specified duration_s always takes priority over the LM's suggested duration. The VRAM offload threshold is automatically lowered when the LM is resident to avoid spurious CPU fallback. Also bind the socket before model loading so the socket file appears within ~1s of startup rather than after the full 30s load time.
1 parent 3c85a67 commit 2d82f56

File tree

1 file changed

+118
-27
lines changed

1 file changed

+118
-27
lines changed

src/bin/generation-daemon.rs

Lines changed: 118 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@
4141
//! ```
4242
4343
use std::path::PathBuf;
44+
use std::sync::{Arc, Mutex};
4445

4546
use ace_step_rs::{
4647
audio::write_audio,
4748
manager::{GenerationManager, ManagerConfig},
49+
model::lm_planner::LmPlanner,
4850
pipeline::GenerationParams,
4951
};
5052
use clap::Parser;
53+
use hf_hub::api::sync::Api;
5154
use serde::{Deserialize, Serialize};
5255
use tokio::{
5356
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
@@ -69,6 +72,13 @@ struct Args {
6972
/// CUDA device ordinal (0 = first GPU).
7073
#[arg(long, default_value_t = 0)]
7174
device: usize,
75+
76+
/// Load the 5Hz LM planner and use it to expand captions before generation.
77+
///
78+
/// Adds ~3.5GB VRAM usage. When enabled, the LM rewrites the caption and
79+
/// fills in BPM, key/scale, time signature, and duration from the text.
80+
#[arg(long, default_value_t = false)]
81+
use_lm: bool,
7282
}
7383

7484
// ── Wire types ───────────────────────────────────────────────────────────────
@@ -87,8 +97,9 @@ struct Request {
8797
#[serde(default = "default_language")]
8898
language: String,
8999

90-
#[serde(default = "default_duration")]
91-
duration_s: f64,
100+
/// Duration in seconds. When absent, the LM planner's suggested duration is used if one is available; otherwise it defaults to 30.
101+
#[serde(default)]
102+
duration_s: Option<f64>,
92103

93104
#[serde(default = "default_shift")]
94105
shift: f64,
@@ -105,9 +116,7 @@ struct Request {
105116
fn default_language() -> String {
106117
"en".into()
107118
}
108-
fn default_duration() -> f64 {
109-
30.0
110-
}
119+
111120
fn default_shift() -> f64 {
112121
3.0
113122
}
@@ -167,22 +176,53 @@ async fn main() -> anyhow::Result<()> {
167176
std::fs::remove_file(&args.socket)?;
168177
}
169178

170-
tracing::info!("Loading ACE-Step pipeline (this may take a minute on first run)...");
179+
// Bind the socket immediately so callers can connect right away.
180+
// Connections that arrive before loading completes queue in the socket's listen backlog until the accept loop below starts.
181+
let listener = UnixListener::bind(&args.socket)?;
182+
tracing::info!("Listening on {:?} (loading pipeline...)", args.socket);
183+
184+
// When the LM planner is resident it consumes ~3.5GB, so the pipeline
185+
// itself only needs ~6.3GB — leave 512MB headroom instead of the default 2GB.
186+
let min_free_vram_bytes = if args.use_lm {
187+
512 * 1024 * 1024
188+
} else {
189+
ManagerConfig::default().min_free_vram_bytes
190+
};
171191
let config = ManagerConfig {
172192
cuda_device: args.device,
193+
min_free_vram_bytes,
173194
..ManagerConfig::default()
174195
};
175196
let manager = GenerationManager::start(config).await?;
176-
tracing::info!("Pipeline ready. Listening on {:?}", args.socket);
177197

178-
let listener = UnixListener::bind(&args.socket)?;
198+
// Optionally load the LM planner on a blocking worker thread; startup awaits it before accepting connections.
199+
let lm_planner: Option<Arc<Mutex<LmPlanner>>> = if args.use_lm {
200+
tracing::info!("Loading 5Hz LM planner...");
201+
let device = ace_step_rs::manager::preferred_device(args.device);
202+
let planner = tokio::task::spawn_blocking(move || -> anyhow::Result<LmPlanner> {
203+
let api = Api::new()?;
204+
let repo = api.model("ACE-Step/Ace-Step1.5".to_string());
205+
let weights = repo.get("acestep-5Hz-lm-1.7B/model.safetensors")?;
206+
let tokenizer = repo.get("acestep-5Hz-lm-1.7B/tokenizer.json")?;
207+
let planner = LmPlanner::load(&weights, &tokenizer, &device, candle_core::DType::BF16)?;
208+
Ok(planner)
209+
})
210+
.await??;
211+
tracing::info!("LM planner ready");
212+
Some(Arc::new(Mutex::new(planner)))
213+
} else {
214+
None
215+
};
216+
217+
tracing::info!("Pipeline ready");
179218

180219
loop {
181220
match listener.accept().await {
182221
Ok((stream, _addr)) => {
183222
let manager = manager.clone();
223+
let lm = lm_planner.clone();
184224
tokio::spawn(async move {
185-
if let Err(e) = handle_connection(stream, manager).await {
225+
if let Err(e) = handle_connection(stream, manager, lm).await {
186226
tracing::warn!("connection error: {e}");
187227
}
188228
});
@@ -196,7 +236,11 @@ async fn main() -> anyhow::Result<()> {
196236

197237
// ── Connection handler ────────────────────────────────────────────────────────
198238

199-
async fn handle_connection(stream: UnixStream, manager: GenerationManager) -> anyhow::Result<()> {
239+
async fn handle_connection(
240+
stream: UnixStream,
241+
manager: GenerationManager,
242+
lm: Option<Arc<Mutex<LmPlanner>>>,
243+
) -> anyhow::Result<()> {
200244
let (reader, mut writer) = stream.into_split();
201245
let mut lines = BufReader::new(reader).lines();
202246

@@ -209,12 +253,16 @@ async fn handle_connection(stream: UnixStream, manager: GenerationManager) -> an
209253
}
210254
};
211255

212-
let response = process_request(&line, &manager).await;
256+
let response = process_request(&line, &manager, lm).await;
213257
send_response(&mut writer, response).await?;
214258
Ok(())
215259
}
216260

217-
async fn process_request(line: &str, manager: &GenerationManager) -> Response {
261+
async fn process_request(
262+
line: &str,
263+
manager: &GenerationManager,
264+
lm: Option<Arc<Mutex<LmPlanner>>>,
265+
) -> Response {
218266
// Parse request.
219267
let req: Request = match serde_json::from_str(line) {
220268
Ok(r) => r,
@@ -225,11 +273,12 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
225273
if req.caption.trim().is_empty() {
226274
return Response::err("'caption' field is required and must not be empty");
227275
}
228-
if req.duration_s < 1.0 || req.duration_s > 600.0 {
229-
return Response::err(format!(
230-
"duration_s must be between 1 and 600, got {}",
231-
req.duration_s
232-
));
276+
if let Some(d) = req.duration_s {
277+
if d < 1.0 || d > 600.0 {
278+
return Response::err(format!(
279+
"duration_s must be between 1 and 600, got {d}"
280+
));
281+
}
233282
}
234283

235284
// Resolve output path.
@@ -272,12 +321,59 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
272321
}
273322
}
274323

324+
// Optionally run the LM planner to expand the caption into structured metadata.
325+
// Resolve duration: user value takes priority; LM suggestion used only when omitted.
326+
const DEFAULT_DURATION: f64 = 30.0;
327+
let user_duration = req.duration_s; // None = user did not specify
328+
329+
let (caption, metas, language, duration_s) =
330+
if let Some(lm_arc) = lm {
331+
let caption = req.caption.clone();
332+
let lyrics = req.lyrics.clone();
333+
let lm_fallback_duration = user_duration.unwrap_or(DEFAULT_DURATION);
334+
let result = tokio::task::spawn_blocking(move || {
335+
let mut planner = lm_arc.lock().unwrap();
336+
planner.plan(&caption, &lyrics, 512, 0.0)
337+
})
338+
.await;
339+
340+
match result {
341+
Ok(Ok(plan)) => {
342+
tracing::info!(
343+
bpm = ?plan.bpm,
344+
keyscale = ?plan.keyscale,
345+
language = ?plan.language,
346+
lm_duration_s = ?plan.duration_s,
347+
"LM planner output"
348+
);
349+
let metas = plan.to_metas_string(lm_fallback_duration);
350+
let caption = plan.caption.unwrap_or(req.caption);
351+
let language = plan.language.unwrap_or(req.language);
352+
// User-specified duration always wins; LM suggestion only if user omitted it.
353+
let duration_s = user_duration
354+
.or_else(|| plan.duration_s.map(|d| d as f64))
355+
.unwrap_or(DEFAULT_DURATION);
356+
(caption, metas, language, duration_s)
357+
}
358+
Ok(Err(e)) => {
359+
tracing::warn!("LM planner failed, falling back to raw caption: {e}");
360+
(req.caption, req.metas, req.language, user_duration.unwrap_or(DEFAULT_DURATION))
361+
}
362+
Err(e) => {
363+
tracing::warn!("LM planner task panicked, falling back: {e}");
364+
(req.caption, req.metas, req.language, user_duration.unwrap_or(DEFAULT_DURATION))
365+
}
366+
}
367+
} else {
368+
(req.caption, req.metas, req.language, user_duration.unwrap_or(DEFAULT_DURATION))
369+
};
370+
275371
let params = GenerationParams {
276-
caption: req.caption,
277-
metas: req.metas,
372+
caption,
373+
metas,
278374
lyrics: req.lyrics,
279-
language: req.language,
280-
duration_s: req.duration_s,
375+
language,
376+
duration_s,
281377
shift: req.shift,
282378
seed: req.seed,
283379
src_latents: None,
@@ -315,12 +411,7 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
315411

316412
tracing::info!(output = %output_path, "done");
317413

318-
Response::ok(
319-
output_path,
320-
req.duration_s,
321-
audio.sample_rate,
322-
audio.channels,
323-
)
414+
Response::ok(output_path, duration_s, audio.sample_rate, audio.channels)
324415
}
325416

326417
async fn send_response(

0 commit comments

Comments
 (0)