Skip to content

Commit 2e2e2b0

Browse files
authored
Merge pull request #5 from Marenz/add-client
Add ace-step-client: simple CLI for the generation daemon
2 parents c758e0e + 3d9579a commit 2e2e2b0

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,7 @@ required-features = ["cli"]
9797
name = "generation-daemon"
9898
path = "src/bin/generation-daemon.rs"
9999
required-features = ["audio-ogg"]
100+
101+
[[bin]]
102+
name = "ace-step-client"
103+
path = "src/bin/ace-step-client.rs"

src/bin/ace-step-client.rs

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
//! Simple command-line client for the ACE-Step generation daemon.
2+
//!
3+
//! Connects to the Unix socket, sends a JSON generation request, waits for the
4+
//! response, and exits 0 on success or 1 on error.
5+
//!
6+
//! # Usage
7+
//!
8+
//! ```sh
9+
//! ace-step-client \
10+
//! --caption "upbeat jazz, 120 BPM" \
11+
//! --output /tmp/music.mp3 \
12+
//! --duration 30
13+
//!
14+
//! # With lyrics:
15+
//! ace-step-client \
16+
//! --caption "silly novelty pop, bouncy" \
17+
//! --lyrics "[verse]\nCat cat cat cat\n[chorus]\nMeow meow meow" \
18+
//! --output /tmp/cat.mp3 \
19+
//! --duration 30
20+
//!
21+
//! # Unload pipeline to free VRAM:
22+
//! ace-step-client --unload
23+
//! ```
24+
25+
use std::{path::PathBuf, time::Duration};
26+
27+
use anyhow::{Context, bail};
28+
use clap::Parser;
29+
use serde::{Deserialize, Serialize};
30+
use tokio::{
31+
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
32+
net::UnixStream,
33+
time::timeout,
34+
};
35+
36+
// Command-line arguments for the client.
//
// Plain `//` comments are used for maintainer notes in this struct on purpose:
// clap's derive turns `///` doc comments into `--help` text, so the `///`
// lines below are user-facing strings and are kept exactly as written.
#[derive(Parser)]
#[command(name = "ace-step-client", about = "Send a generation request to the ACE-Step daemon")]
struct Args {
    // Optional at the parser level because `--unload` does not need it;
    // `main` rejects a generation request that has no caption.
    /// Style description: genre, mood, tempo, instruments
    #[arg(long)]
    caption: Option<String>,

    // When absent, the `output` key is left out of the JSON request entirely.
    /// Output file path (.mp3, .ogg, or .wav)
    #[arg(long)]
    output: Option<PathBuf>,

    // Sent to the daemon as `duration_s`.
    /// Duration in seconds (default: 30)
    #[arg(long, default_value = "30.0")]
    duration: f64,

    /// Lyrics with [verse]/[chorus]/[bridge] tags; omit for instrumental
    #[arg(long)]
    lyrics: Option<String>,

    /// Metadata string, e.g. "bpm: 120, key: C major"
    #[arg(long)]
    metas: Option<String>,

    /// Lyrics language code (default: en)
    #[arg(long, default_value = "en")]
    language: String,

    // NOTE(review): the 1–3 range is only stated in the help text; nothing in
    // this file validates it, so the daemon presumably enforces it — confirm.
    /// ODE schedule shift 1–3 (default: 3.0)
    #[arg(long, default_value = "3.0")]
    shift: f64,

    /// Fixed seed for reproducibility (omit for random)
    #[arg(long)]
    seed: Option<u64>,

    /// Socket path (default: /tmp/ace-step-gen.sock)
    #[arg(long, default_value = "/tmp/ace-step-gen.sock")]
    socket: PathBuf,

    // Bounds only the wait for the daemon's response line; connecting to the
    // socket uses a separate fixed 10-second timeout in `main`.
    /// Timeout in seconds to wait for generation (default: 300)
    #[arg(long, default_value = "300")]
    timeout_secs: u64,

    // When set, `main` sends an "unload" command and ignores every
    // generation-related option above.
    /// Unload the pipeline from VRAM instead of generating
    #[arg(long)]
    unload: bool,
}
83+
84+
/// A single message sent to the daemon, serialized as one line of JSON.
///
/// `#[serde(untagged)]` means the variant name never appears on the wire: the
/// JSON is just the fields of the wrapped struct.
#[derive(Serialize)]
#[serde(untagged)]
enum Request {
    /// Ask the daemon to generate audio.
    Generate(GenerateRequest),
    /// Send a control command (this client only ever sends `"unload"`).
    Command(CommandRequest),
}
90+
91+
/// JSON body of a generation request.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so
/// an unset option is omitted from the JSON object rather than sent as `null`.
#[derive(Serialize)]
struct GenerateRequest {
    /// Style description, taken from `--caption`.
    caption: String,
    /// Destination file path as a string, taken from `--output`.
    #[serde(skip_serializing_if = "Option::is_none")]
    output: Option<String>,
    /// Requested length in seconds, taken from `--duration`.
    duration_s: f64,
    /// Lyrics text, taken from `--lyrics`.
    #[serde(skip_serializing_if = "Option::is_none")]
    lyrics: Option<String>,
    /// Free-form metadata string, taken from `--metas`.
    #[serde(skip_serializing_if = "Option::is_none")]
    metas: Option<String>,
    /// Lyrics language code, taken from `--language`.
    language: String,
    /// ODE schedule shift, taken from `--shift`.
    shift: f64,
    /// Fixed seed, taken from `--seed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
}
106+
107+
/// JSON body of a control command, serialized as `{"command": "<name>"}`.
#[derive(Serialize)]
struct CommandRequest {
    /// Command name; `main` sends `"unload"` when `--unload` is given.
    command: String,
}
111+
112+
#[derive(Deserialize)]
113+
#[serde(untagged)]
114+
enum Response {
115+
Success(SuccessResponse),
116+
Error(ErrorResponse),
117+
}
118+
119+
/// Reply shape without an error message.
///
/// Only `ok` is required; `path` and `duration_s` are optional, so any JSON
/// object with a boolean `ok` deserializes into this struct. Callers must
/// therefore still check `ok` rather than assume success.
#[derive(Deserialize)]
struct SuccessResponse {
    /// Whether the daemon reported success.
    ok: bool,
    /// Path reported by the daemon; `main` prints it to stdout when present.
    path: Option<String>,
    /// Duration in seconds reported by the daemon, if any.
    duration_s: Option<f64>,
}
125+
126+
/// Reply shape that carries an error message from the daemon.
#[derive(Deserialize)]
struct ErrorResponse {
    // Never read: the field is declared only so that a reply must contain a
    // boolean `ok` to match this struct, hence the `dead_code` allowance.
    #[allow(dead_code)]
    ok: bool,
    /// Error message; `main` reports it as "generation failed: <error>".
    error: String,
}
132+
133+
#[tokio::main]
134+
async fn main() -> anyhow::Result<()> {
135+
let args = Args::parse();
136+
137+
let request = if args.unload {
138+
Request::Command(CommandRequest { command: "unload".into() })
139+
} else {
140+
let caption = args.caption.context("--caption is required for generation")?;
141+
Request::Generate(GenerateRequest {
142+
caption,
143+
output: args.output.map(|p| p.to_string_lossy().into_owned()),
144+
duration_s: args.duration,
145+
lyrics: args.lyrics,
146+
metas: args.metas,
147+
language: args.language,
148+
shift: args.shift,
149+
seed: args.seed,
150+
})
151+
};
152+
153+
let request_line = serde_json::to_string(&request)? + "\n";
154+
155+
let stream = timeout(Duration::from_secs(10), UnixStream::connect(&args.socket))
156+
.await
157+
.context("timed out connecting to daemon socket")?
158+
.with_context(|| format!("failed to connect to {}", args.socket.display()))?;
159+
160+
let (reader, mut writer) = stream.into_split();
161+
162+
writer
163+
.write_all(request_line.as_bytes())
164+
.await
165+
.context("failed to send request")?;
166+
writer.flush().await?;
167+
// Signal EOF so the daemon knows we're done writing.
168+
drop(writer);
169+
170+
let mut reader = BufReader::new(reader);
171+
let mut response_line = String::new();
172+
173+
timeout(Duration::from_secs(args.timeout_secs), reader.read_line(&mut response_line))
174+
.await
175+
.context("timed out waiting for daemon response")?
176+
.context("failed to read response")?;
177+
178+
if response_line.is_empty() {
179+
bail!("daemon closed connection without sending a response");
180+
}
181+
182+
let response: Response =
183+
serde_json::from_str(response_line.trim()).context("failed to parse daemon response")?;
184+
185+
match response {
186+
Response::Success(r) if r.ok => {
187+
if let Some(path) = r.path {
188+
if let Some(duration) = r.duration_s {
189+
eprintln!("generated {:.1}s of audio → {path}", duration);
190+
} else {
191+
eprintln!("done → {path}");
192+
}
193+
println!("{path}");
194+
} else {
195+
eprintln!("ok");
196+
}
197+
Ok(())
198+
}
199+
Response::Success(r) => {
200+
bail!("daemon returned ok=false without error field (raw: {:?})", r.path);
201+
}
202+
Response::Error(r) => {
203+
bail!("generation failed: {}", r.error);
204+
}
205+
}
206+
}

0 commit comments

Comments
 (0)