|
| 1 | +//! Simple command-line client for the ACE-Step generation daemon. |
| 2 | +//! |
| 3 | +//! Connects to the Unix socket, sends a JSON generation request, waits for the |
| 4 | +//! response, and exits 0 on success or 1 on error. |
| 5 | +//! |
| 6 | +//! # Usage |
| 7 | +//! |
| 8 | +//! ```sh |
| 9 | +//! ace-step-client \ |
| 10 | +//! --caption "upbeat jazz, 120 BPM" \ |
| 11 | +//! --output /tmp/music.mp3 \ |
| 12 | +//! --duration 30 |
| 13 | +//! |
| 14 | +//! # With lyrics: |
| 15 | +//! ace-step-client \ |
| 16 | +//! --caption "silly novelty pop, bouncy" \ |
| 17 | +//! --lyrics "[verse]\nCat cat cat cat\n[chorus]\nMeow meow meow" \ |
| 18 | +//! --output /tmp/cat.mp3 \ |
| 19 | +//! --duration 30 |
| 20 | +//! |
| 21 | +//! # Unload pipeline to free VRAM: |
| 22 | +//! ace-step-client --unload |
| 23 | +//! ``` |
| 24 | +
|
| 25 | +use std::{path::PathBuf, time::Duration}; |
| 26 | + |
| 27 | +use anyhow::{Context, bail}; |
| 28 | +use clap::Parser; |
| 29 | +use serde::{Deserialize, Serialize}; |
| 30 | +use tokio::{ |
| 31 | + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, |
| 32 | + net::UnixStream, |
| 33 | + time::timeout, |
| 34 | +}; |
| 35 | + |
| 36 | +#[derive(Parser)] |
| 37 | +#[command(name = "ace-step-client", about = "Send a generation request to the ACE-Step daemon")] |
| 38 | +struct Args { |
| 39 | + /// Style description: genre, mood, tempo, instruments |
| 40 | + #[arg(long)] |
| 41 | + caption: Option<String>, |
| 42 | + |
| 43 | + /// Output file path (.mp3, .ogg, or .wav) |
| 44 | + #[arg(long)] |
| 45 | + output: Option<PathBuf>, |
| 46 | + |
| 47 | + /// Duration in seconds (default: 30) |
| 48 | + #[arg(long, default_value = "30.0")] |
| 49 | + duration: f64, |
| 50 | + |
| 51 | + /// Lyrics with [verse]/[chorus]/[bridge] tags; omit for instrumental |
| 52 | + #[arg(long)] |
| 53 | + lyrics: Option<String>, |
| 54 | + |
| 55 | + /// Metadata string, e.g. "bpm: 120, key: C major" |
| 56 | + #[arg(long)] |
| 57 | + metas: Option<String>, |
| 58 | + |
| 59 | + /// Lyrics language code (default: en) |
| 60 | + #[arg(long, default_value = "en")] |
| 61 | + language: String, |
| 62 | + |
| 63 | + /// ODE schedule shift 1–3 (default: 3.0) |
| 64 | + #[arg(long, default_value = "3.0")] |
| 65 | + shift: f64, |
| 66 | + |
| 67 | + /// Fixed seed for reproducibility (omit for random) |
| 68 | + #[arg(long)] |
| 69 | + seed: Option<u64>, |
| 70 | + |
| 71 | + /// Socket path (default: /tmp/ace-step-gen.sock) |
| 72 | + #[arg(long, default_value = "/tmp/ace-step-gen.sock")] |
| 73 | + socket: PathBuf, |
| 74 | + |
| 75 | + /// Timeout in seconds to wait for generation (default: 300) |
| 76 | + #[arg(long, default_value = "300")] |
| 77 | + timeout_secs: u64, |
| 78 | + |
| 79 | + /// Unload the pipeline from VRAM instead of generating |
| 80 | + #[arg(long)] |
| 81 | + unload: bool, |
| 82 | +} |
| 83 | + |
| 84 | +#[derive(Serialize)] |
| 85 | +#[serde(untagged)] |
| 86 | +enum Request { |
| 87 | + Generate(GenerateRequest), |
| 88 | + Command(CommandRequest), |
| 89 | +} |
| 90 | + |
| 91 | +#[derive(Serialize)] |
| 92 | +struct GenerateRequest { |
| 93 | + caption: String, |
| 94 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 95 | + output: Option<String>, |
| 96 | + duration_s: f64, |
| 97 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 98 | + lyrics: Option<String>, |
| 99 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 100 | + metas: Option<String>, |
| 101 | + language: String, |
| 102 | + shift: f64, |
| 103 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 104 | + seed: Option<u64>, |
| 105 | +} |
| 106 | + |
| 107 | +#[derive(Serialize)] |
| 108 | +struct CommandRequest { |
| 109 | + command: String, |
| 110 | +} |
| 111 | + |
| 112 | +#[derive(Deserialize)] |
| 113 | +#[serde(untagged)] |
| 114 | +enum Response { |
| 115 | + Success(SuccessResponse), |
| 116 | + Error(ErrorResponse), |
| 117 | +} |
| 118 | + |
| 119 | +#[derive(Deserialize)] |
| 120 | +struct SuccessResponse { |
| 121 | + ok: bool, |
| 122 | + path: Option<String>, |
| 123 | + duration_s: Option<f64>, |
| 124 | +} |
| 125 | + |
| 126 | +#[derive(Deserialize)] |
| 127 | +struct ErrorResponse { |
| 128 | + #[allow(dead_code)] |
| 129 | + ok: bool, |
| 130 | + error: String, |
| 131 | +} |
| 132 | + |
| 133 | +#[tokio::main] |
| 134 | +async fn main() -> anyhow::Result<()> { |
| 135 | + let args = Args::parse(); |
| 136 | + |
| 137 | + let request = if args.unload { |
| 138 | + Request::Command(CommandRequest { command: "unload".into() }) |
| 139 | + } else { |
| 140 | + let caption = args.caption.context("--caption is required for generation")?; |
| 141 | + Request::Generate(GenerateRequest { |
| 142 | + caption, |
| 143 | + output: args.output.map(|p| p.to_string_lossy().into_owned()), |
| 144 | + duration_s: args.duration, |
| 145 | + lyrics: args.lyrics, |
| 146 | + metas: args.metas, |
| 147 | + language: args.language, |
| 148 | + shift: args.shift, |
| 149 | + seed: args.seed, |
| 150 | + }) |
| 151 | + }; |
| 152 | + |
| 153 | + let request_line = serde_json::to_string(&request)? + "\n"; |
| 154 | + |
| 155 | + let stream = timeout(Duration::from_secs(10), UnixStream::connect(&args.socket)) |
| 156 | + .await |
| 157 | + .context("timed out connecting to daemon socket")? |
| 158 | + .with_context(|| format!("failed to connect to {}", args.socket.display()))?; |
| 159 | + |
| 160 | + let (reader, mut writer) = stream.into_split(); |
| 161 | + |
| 162 | + writer |
| 163 | + .write_all(request_line.as_bytes()) |
| 164 | + .await |
| 165 | + .context("failed to send request")?; |
| 166 | + writer.flush().await?; |
| 167 | + // Signal EOF so the daemon knows we're done writing. |
| 168 | + drop(writer); |
| 169 | + |
| 170 | + let mut reader = BufReader::new(reader); |
| 171 | + let mut response_line = String::new(); |
| 172 | + |
| 173 | + timeout(Duration::from_secs(args.timeout_secs), reader.read_line(&mut response_line)) |
| 174 | + .await |
| 175 | + .context("timed out waiting for daemon response")? |
| 176 | + .context("failed to read response")?; |
| 177 | + |
| 178 | + if response_line.is_empty() { |
| 179 | + bail!("daemon closed connection without sending a response"); |
| 180 | + } |
| 181 | + |
| 182 | + let response: Response = |
| 183 | + serde_json::from_str(response_line.trim()).context("failed to parse daemon response")?; |
| 184 | + |
| 185 | + match response { |
| 186 | + Response::Success(r) if r.ok => { |
| 187 | + if let Some(path) = r.path { |
| 188 | + if let Some(duration) = r.duration_s { |
| 189 | + eprintln!("generated {:.1}s of audio → {path}", duration); |
| 190 | + } else { |
| 191 | + eprintln!("done → {path}"); |
| 192 | + } |
| 193 | + println!("{path}"); |
| 194 | + } else { |
| 195 | + eprintln!("ok"); |
| 196 | + } |
| 197 | + Ok(()) |
| 198 | + } |
| 199 | + Response::Success(r) => { |
| 200 | + bail!("daemon returned ok=false without error field (raw: {:?})", r.path); |
| 201 | + } |
| 202 | + Response::Error(r) => { |
| 203 | + bail!("generation failed: {}", r.error); |
| 204 | + } |
| 205 | + } |
| 206 | +} |
0 commit comments