Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ src/
│ │ ├── fsq.rs — ResidualFsq (stub, cover mode only)
│ │ ├── pooler.rs — AttentionPooler (25Hz→5Hz)
│ │ └── detokenizer.rs — AudioTokenDetokenizer (5Hz→25Hz)
│ ├── lm_planner.rs — LmPlanner (5Hz LM CoT-only), PlannerOutput
│ └── generation.rs — AceStepConditionGenerationModel (ODE loop)
├── vae.rs — OobleckDecoder (Snake1d α+β, weight_norm)
└── pipeline.rs — end-to-end inference (stub)
Expand Down Expand Up @@ -155,24 +156,28 @@ Module roots (e.g., `src/audio.rs`) contain `mod` declarations and re-exports. N
**Socket:** `/tmp/ace-step-gen.sock` (default). Override with `--socket`.

**Protocol:**
- Request: `{"caption":"...", "output":"/tmp/out.ogg", "duration_s":30, ...}` + newline
- Request: `{"caption":"...", "output":"/tmp/out.ogg", "duration_s":30, "lyrics":"...", "metas":"...", "language":"en", "shift":3.0, "seed":null}` + newline
- Response success: `{"ok":true, "path":"...", "duration_s":30, "sample_rate":48000, "channels":2}` + newline
- Response error: `{"ok":false, "error":"..."}` + newline

**LM planner flag:** Pass `--use-lm` to load the 5Hz LM planner (1.7B, ~3.5GB VRAM). When active, the LM rewrites the caption and fills in BPM, key/scale, time signature, and language before the DiT runs. User-specified `duration_s` always takes priority; the LM suggestion is only used when duration is omitted. Recommended for best quality.

**Socket bind:** The socket is bound before model loading completes, so the socket file appears within ~1s of startup. Connections that arrive during loading queue in the channel and are served once the pipeline is ready.

**Build:**
```bash
LIBRARY_PATH=/usr/lib64:$LIBRARY_PATH PATH="/usr/local/cuda-12.4/bin:$PATH" \
CUDA_HOME=/usr/local/cuda-12.4 NVCC_CCBIN=/usr/bin/g++-13 CPLUS_INCLUDE_PATH="/tmp/cuda-shim" \
cargo build --release --example generation_daemon --features audio-ogg
cargo build --release --features audio-all
```

**Binary location:** `target/release/examples/generation_daemon`
**Binary location:** `target/release/generation-daemon`

**Systemd unit:** `~/.config/systemd/user/ace-step-gen.service` (enabled, auto-starts)

**Skill:** `~/.spacebot/skills/generate_music/SKILL.md` — instructs the worker to talk to the daemon socket via `socat`, with CLI binary fallback.

**Design rationale:** `stream_daemon` (also in this repo) is for live/continuous playback to speakers. `generation_daemon` is for one-shot file generation: user asks for a track, gets a file. The `GenerationManager` (`src/manager.rs`) handles the resident pipeline with OOM retry; `generation_daemon` just wraps it with a socket interface.
**Design rationale:** `generation_daemon` is for one-shot file generation: user asks for a track, gets a file. The `GenerationManager` (`src/manager.rs`) handles the resident pipeline with OOM retry; `generation_daemon` wraps it with a socket interface. The optional LM planner runs before each request (no separate socket — same process).

## Reference Repositories

Expand Down
157 changes: 130 additions & 27 deletions src/bin/generation-daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@
//! ```

use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use ace_step_rs::{
audio::write_audio,
manager::{GenerationManager, ManagerConfig},
model::lm_planner::LmPlanner,
pipeline::GenerationParams,
};
use clap::Parser;
use hf_hub::api::sync::Api;
use serde::{Deserialize, Serialize};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
Expand All @@ -69,6 +72,13 @@ struct Args {
/// CUDA device ordinal (0 = first GPU).
#[arg(long, default_value_t = 0)]
device: usize,

/// Load the 5Hz LM planner and use it to expand captions before generation.
///
/// Adds ~3.5GB VRAM usage. When enabled, the LM rewrites the caption and
/// fills in BPM, key/scale, time signature, and duration from the text.
#[arg(long, default_value_t = false)]
use_lm: bool,
}

// ── Wire types ───────────────────────────────────────────────────────────────
Expand All @@ -87,8 +97,9 @@ struct Request {
#[serde(default = "default_language")]
language: String,

#[serde(default = "default_duration")]
duration_s: f64,
/// Duration in seconds. When absent the LM planner may suggest one; otherwise defaults to 30.
#[serde(default)]
duration_s: Option<f64>,

#[serde(default = "default_shift")]
shift: f64,
Expand All @@ -105,9 +116,7 @@ struct Request {
fn default_language() -> String {
"en".into()
}
fn default_duration() -> f64 {
30.0
}

fn default_shift() -> f64 {
3.0
}
Expand Down Expand Up @@ -167,22 +176,53 @@ async fn main() -> anyhow::Result<()> {
std::fs::remove_file(&args.socket)?;
}

tracing::info!("Loading ACE-Step pipeline (this may take a minute on first run)...");
// Bind the socket immediately so callers can connect right away.
// Connections that arrive before loading completes will wait in the channel.
let listener = UnixListener::bind(&args.socket)?;
tracing::info!("Listening on {:?} (loading pipeline...)", args.socket);

// When the LM planner is resident it consumes ~3.5GB, so the pipeline
// itself only needs ~6.3GB — leave 512MB headroom instead of the default 2GB.
let min_free_vram_bytes = if args.use_lm {
512 * 1024 * 1024
} else {
ManagerConfig::default().min_free_vram_bytes
};
let config = ManagerConfig {
cuda_device: args.device,
min_free_vram_bytes,
..ManagerConfig::default()
};
let manager = GenerationManager::start(config).await?;
tracing::info!("Pipeline ready. Listening on {:?}", args.socket);

let listener = UnixListener::bind(&args.socket)?;
// Optionally load the LM planner (blocking, on the current thread).
let lm_planner: Option<Arc<Mutex<LmPlanner>>> = if args.use_lm {
tracing::info!("Loading 5Hz LM planner...");
let device = ace_step_rs::manager::preferred_device(args.device);
let planner = tokio::task::spawn_blocking(move || -> anyhow::Result<LmPlanner> {
let api = Api::new()?;
let repo = api.model("ACE-Step/Ace-Step1.5".to_string());
let weights = repo.get("acestep-5Hz-lm-1.7B/model.safetensors")?;
let tokenizer = repo.get("acestep-5Hz-lm-1.7B/tokenizer.json")?;
let planner = LmPlanner::load(&weights, &tokenizer, &device, candle_core::DType::BF16)?;
Ok(planner)
})
.await??;
tracing::info!("LM planner ready");
Some(Arc::new(Mutex::new(planner)))
} else {
None
};

tracing::info!("Pipeline ready");

loop {
match listener.accept().await {
Ok((stream, _addr)) => {
let manager = manager.clone();
let lm = lm_planner.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, manager).await {
if let Err(e) = handle_connection(stream, manager, lm).await {
tracing::warn!("connection error: {e}");
}
});
Expand All @@ -196,7 +236,11 @@ async fn main() -> anyhow::Result<()> {

// ── Connection handler ────────────────────────────────────────────────────────

async fn handle_connection(stream: UnixStream, manager: GenerationManager) -> anyhow::Result<()> {
async fn handle_connection(
stream: UnixStream,
manager: GenerationManager,
lm: Option<Arc<Mutex<LmPlanner>>>,
) -> anyhow::Result<()> {
let (reader, mut writer) = stream.into_split();
let mut lines = BufReader::new(reader).lines();

Expand All @@ -209,12 +253,16 @@ async fn handle_connection(stream: UnixStream, manager: GenerationManager) -> an
}
};

let response = process_request(&line, &manager).await;
let response = process_request(&line, &manager, lm).await;
send_response(&mut writer, response).await?;
Ok(())
}

async fn process_request(line: &str, manager: &GenerationManager) -> Response {
async fn process_request(
line: &str,
manager: &GenerationManager,
lm: Option<Arc<Mutex<LmPlanner>>>,
) -> Response {
// Parse request.
let req: Request = match serde_json::from_str(line) {
Ok(r) => r,
Expand All @@ -225,11 +273,10 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
if req.caption.trim().is_empty() {
return Response::err("'caption' field is required and must not be empty");
}
if req.duration_s < 1.0 || req.duration_s > 600.0 {
return Response::err(format!(
"duration_s must be between 1 and 600, got {}",
req.duration_s
));
if let Some(d) = req.duration_s {
if d < 1.0 || d > 600.0 {
return Response::err(format!("duration_s must be between 1 and 600, got {d}"));
}
}

// Resolve output path.
Expand Down Expand Up @@ -272,12 +319,73 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {
}
}

// Optionally run the LM planner to expand the caption into structured metadata.
// Resolve duration: user value takes priority; LM suggestion used only when omitted.
const DEFAULT_DURATION: f64 = 30.0;
let user_duration = req.duration_s; // None = user did not specify

let (caption, metas, language, duration_s) = if let Some(lm_arc) = lm {
let caption = req.caption.clone();
let lyrics = req.lyrics.clone();
let lm_fallback_duration = user_duration.unwrap_or(DEFAULT_DURATION);
let result = tokio::task::spawn_blocking(move || {
let mut planner = lm_arc.lock().unwrap();
planner.plan(&caption, &lyrics, 512, 0.0)
})
.await;

match result {
Ok(Ok(plan)) => {
tracing::info!(
bpm = ?plan.bpm,
keyscale = ?plan.keyscale,
language = ?plan.language,
lm_duration_s = ?plan.duration_s,
"LM planner output"
);
let metas = plan.to_metas_string(lm_fallback_duration);
let caption = plan.caption.unwrap_or(req.caption);
let language = plan.language.unwrap_or(req.language);
// User-specified duration always wins; LM suggestion only if user omitted it.
let duration_s = user_duration
.or_else(|| plan.duration_s.map(|d| d as f64))
.unwrap_or(DEFAULT_DURATION);
(caption, metas, language, duration_s)
}
Ok(Err(e)) => {
tracing::warn!("LM planner failed, falling back to raw caption: {e}");
(
req.caption,
req.metas,
req.language,
user_duration.unwrap_or(DEFAULT_DURATION),
)
}
Err(e) => {
tracing::warn!("LM planner task panicked, falling back: {e}");
(
req.caption,
req.metas,
req.language,
user_duration.unwrap_or(DEFAULT_DURATION),
)
}
}
} else {
(
req.caption,
req.metas,
req.language,
user_duration.unwrap_or(DEFAULT_DURATION),
)
};

let params = GenerationParams {
caption: req.caption,
metas: req.metas,
caption,
metas,
lyrics: req.lyrics,
language: req.language,
duration_s: req.duration_s,
language,
duration_s,
shift: req.shift,
seed: req.seed,
src_latents: None,
Expand Down Expand Up @@ -315,12 +423,7 @@ async fn process_request(line: &str, manager: &GenerationManager) -> Response {

tracing::info!(output = %output_path, "done");

Response::ok(
output_path,
req.duration_s,
audio.sample_rate,
audio.channels,
)
Response::ok(output_path, duration_s, audio.sample_rate, audio.channels)
}

async fn send_response(
Expand Down
2 changes: 2 additions & 0 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
//! - [`encoder`] — condition encoder (lyric encoder, timbre encoder, text projector)
//! - [`tokenizer`] — audio tokenizer (FSQ) and detokenizer
//! - [`generation`] — top-level generation model combining all components
//! - [`lm_planner`] — 5Hz LM planner (CoT-only), expands raw caption → structured metadata

pub mod encoder;
pub mod generation;
pub mod lm_planner;
pub mod tokenizer;
pub mod transformer;
Loading