4141//! ```
4242
4343use std:: path:: PathBuf ;
44+ use std:: sync:: { Arc , Mutex } ;
4445
4546use ace_step_rs:: {
4647 audio:: write_audio,
4748 manager:: { GenerationManager , ManagerConfig } ,
49+ model:: lm_planner:: LmPlanner ,
4850 pipeline:: GenerationParams ,
4951} ;
5052use clap:: Parser ;
53+ use hf_hub:: api:: sync:: Api ;
5154use serde:: { Deserialize , Serialize } ;
5255use tokio:: {
5356 io:: { AsyncBufReadExt , AsyncWriteExt , BufReader } ,
@@ -69,6 +72,13 @@ struct Args {
6972 /// CUDA device ordinal (0 = first GPU).
7073 #[ arg( long, default_value_t = 0 ) ]
7174 device : usize ,
75+
76+ /// Load the 5Hz LM planner and use it to expand captions before generation.
77+ ///
78+ /// Adds ~3.5GB VRAM usage. When enabled, the LM rewrites the caption and
79+ /// fills in BPM, key/scale, time signature, and duration from the text.
80+ #[ arg( long, default_value_t = false ) ]
81+ use_lm : bool ,
7282}
7383
7484// ── Wire types ───────────────────────────────────────────────────────────────
@@ -87,8 +97,9 @@ struct Request {
8797 #[ serde( default = "default_language" ) ]
8898 language : String ,
8999
90- #[ serde( default = "default_duration" ) ]
91- duration_s : f64 ,
100+ /// Duration in seconds. When absent the LM planner may suggest one; otherwise defaults to 30.
101+ #[ serde( default ) ]
102+ duration_s : Option < f64 > ,
92103
93104 #[ serde( default = "default_shift" ) ]
94105 shift : f64 ,
@@ -105,9 +116,7 @@ struct Request {
105116fn default_language ( ) -> String {
106117 "en" . into ( )
107118}
108- fn default_duration ( ) -> f64 {
109- 30.0
110- }
119+
111120fn default_shift ( ) -> f64 {
112121 3.0
113122}
@@ -167,22 +176,53 @@ async fn main() -> anyhow::Result<()> {
167176 std:: fs:: remove_file ( & args. socket ) ?;
168177 }
169178
170- tracing:: info!( "Loading ACE-Step pipeline (this may take a minute on first run)..." ) ;
179+ // Bind the socket immediately so callers can connect right away.
180+ // Connections that arrive before loading completes will wait in the channel.
181+ let listener = UnixListener :: bind ( & args. socket ) ?;
182+ tracing:: info!( "Listening on {:?} (loading pipeline...)" , args. socket) ;
183+
184+ // When the LM planner is resident it consumes ~3.5GB, so the pipeline
185+ // itself only needs ~6.3GB — leave 512MB headroom instead of the default 2GB.
186+ let min_free_vram_bytes = if args. use_lm {
187+ 512 * 1024 * 1024
188+ } else {
189+ ManagerConfig :: default ( ) . min_free_vram_bytes
190+ } ;
171191 let config = ManagerConfig {
172192 cuda_device : args. device ,
193+ min_free_vram_bytes,
173194 ..ManagerConfig :: default ( )
174195 } ;
175196 let manager = GenerationManager :: start ( config) . await ?;
176- tracing:: info!( "Pipeline ready. Listening on {:?}" , args. socket) ;
177197
178- let listener = UnixListener :: bind ( & args. socket ) ?;
198+ // Optionally load the LM planner (blocking, on the current thread).
199+ let lm_planner: Option < Arc < Mutex < LmPlanner > > > = if args. use_lm {
200+ tracing:: info!( "Loading 5Hz LM planner..." ) ;
201+ let device = ace_step_rs:: manager:: preferred_device ( args. device ) ;
202+ let planner = tokio:: task:: spawn_blocking ( move || -> anyhow:: Result < LmPlanner > {
203+ let api = Api :: new ( ) ?;
204+ let repo = api. model ( "ACE-Step/Ace-Step1.5" . to_string ( ) ) ;
205+ let weights = repo. get ( "acestep-5Hz-lm-1.7B/model.safetensors" ) ?;
206+ let tokenizer = repo. get ( "acestep-5Hz-lm-1.7B/tokenizer.json" ) ?;
207+ let planner = LmPlanner :: load ( & weights, & tokenizer, & device, candle_core:: DType :: BF16 ) ?;
208+ Ok ( planner)
209+ } )
210+ . await ??;
211+ tracing:: info!( "LM planner ready" ) ;
212+ Some ( Arc :: new ( Mutex :: new ( planner) ) )
213+ } else {
214+ None
215+ } ;
216+
217+ tracing:: info!( "Pipeline ready" ) ;
179218
180219 loop {
181220 match listener. accept ( ) . await {
182221 Ok ( ( stream, _addr) ) => {
183222 let manager = manager. clone ( ) ;
223+ let lm = lm_planner. clone ( ) ;
184224 tokio:: spawn ( async move {
185- if let Err ( e) = handle_connection ( stream, manager) . await {
225+ if let Err ( e) = handle_connection ( stream, manager, lm ) . await {
186226 tracing:: warn!( "connection error: {e}" ) ;
187227 }
188228 } ) ;
@@ -196,7 +236,11 @@ async fn main() -> anyhow::Result<()> {
196236
197237// ── Connection handler ────────────────────────────────────────────────────────
198238
199- async fn handle_connection ( stream : UnixStream , manager : GenerationManager ) -> anyhow:: Result < ( ) > {
239+ async fn handle_connection (
240+ stream : UnixStream ,
241+ manager : GenerationManager ,
242+ lm : Option < Arc < Mutex < LmPlanner > > > ,
243+ ) -> anyhow:: Result < ( ) > {
200244 let ( reader, mut writer) = stream. into_split ( ) ;
201245 let mut lines = BufReader :: new ( reader) . lines ( ) ;
202246
@@ -209,12 +253,16 @@ async fn handle_connection(stream: UnixStream, manager: GenerationManager) -> an
209253 }
210254 } ;
211255
212- let response = process_request ( & line, & manager) . await ;
256+ let response = process_request ( & line, & manager, lm ) . await ;
213257 send_response ( & mut writer, response) . await ?;
214258 Ok ( ( ) )
215259}
216260
217- async fn process_request ( line : & str , manager : & GenerationManager ) -> Response {
261+ async fn process_request (
262+ line : & str ,
263+ manager : & GenerationManager ,
264+ lm : Option < Arc < Mutex < LmPlanner > > > ,
265+ ) -> Response {
218266 // Parse request.
219267 let req: Request = match serde_json:: from_str ( line) {
220268 Ok ( r) => r,
@@ -225,11 +273,12 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
225273 if req. caption . trim ( ) . is_empty ( ) {
226274 return Response :: err ( "'caption' field is required and must not be empty" ) ;
227275 }
228- if req. duration_s < 1.0 || req. duration_s > 600.0 {
229- return Response :: err ( format ! (
230- "duration_s must be between 1 and 600, got {}" ,
231- req. duration_s
232- ) ) ;
276+ if let Some ( d) = req. duration_s {
277+ if d < 1.0 || d > 600.0 {
278+ return Response :: err ( format ! (
279+ "duration_s must be between 1 and 600, got {d}"
280+ ) ) ;
281+ }
233282 }
234283
235284 // Resolve output path.
@@ -272,12 +321,59 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
272321 }
273322 }
274323
324+ // Optionally run the LM planner to expand the caption into structured metadata.
325+ // Resolve duration: user value takes priority; LM suggestion used only when omitted.
326+ const DEFAULT_DURATION : f64 = 30.0 ;
327+ let user_duration = req. duration_s ; // None = user did not specify
328+
329+ let ( caption, metas, language, duration_s) =
330+ if let Some ( lm_arc) = lm {
331+ let caption = req. caption . clone ( ) ;
332+ let lyrics = req. lyrics . clone ( ) ;
333+ let lm_fallback_duration = user_duration. unwrap_or ( DEFAULT_DURATION ) ;
334+ let result = tokio:: task:: spawn_blocking ( move || {
335+ let mut planner = lm_arc. lock ( ) . unwrap ( ) ;
336+ planner. plan ( & caption, & lyrics, 512 , 0.0 )
337+ } )
338+ . await ;
339+
340+ match result {
341+ Ok ( Ok ( plan) ) => {
342+ tracing:: info!(
343+ bpm = ?plan. bpm,
344+ keyscale = ?plan. keyscale,
345+ language = ?plan. language,
346+ lm_duration_s = ?plan. duration_s,
347+ "LM planner output"
348+ ) ;
349+ let metas = plan. to_metas_string ( lm_fallback_duration) ;
350+ let caption = plan. caption . unwrap_or ( req. caption ) ;
351+ let language = plan. language . unwrap_or ( req. language ) ;
352+ // User-specified duration always wins; LM suggestion only if user omitted it.
353+ let duration_s = user_duration
354+ . or_else ( || plan. duration_s . map ( |d| d as f64 ) )
355+ . unwrap_or ( DEFAULT_DURATION ) ;
356+ ( caption, metas, language, duration_s)
357+ }
358+ Ok ( Err ( e) ) => {
359+ tracing:: warn!( "LM planner failed, falling back to raw caption: {e}" ) ;
360+ ( req. caption , req. metas , req. language , user_duration. unwrap_or ( DEFAULT_DURATION ) )
361+ }
362+ Err ( e) => {
363+ tracing:: warn!( "LM planner task panicked, falling back: {e}" ) ;
364+ ( req. caption , req. metas , req. language , user_duration. unwrap_or ( DEFAULT_DURATION ) )
365+ }
366+ }
367+ } else {
368+ ( req. caption , req. metas , req. language , user_duration. unwrap_or ( DEFAULT_DURATION ) )
369+ } ;
370+
275371 let params = GenerationParams {
276- caption : req . caption ,
277- metas : req . metas ,
372+ caption,
373+ metas,
278374 lyrics : req. lyrics ,
279- language : req . language ,
280- duration_s : req . duration_s ,
375+ language,
376+ duration_s,
281377 shift : req. shift ,
282378 seed : req. seed ,
283379 src_latents : None ,
@@ -315,12 +411,7 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
315411
316412 tracing:: info!( output = %output_path, "done" ) ;
317413
318- Response :: ok (
319- output_path,
320- req. duration_s ,
321- audio. sample_rate ,
322- audio. channels ,
323- )
414+ Response :: ok ( output_path, duration_s, audio. sample_rate , audio. channels )
324415}
325416
326417async fn send_response (
0 commit comments