Skip to content

Commit ff77a60

Browse files
authored
[CHORE] Tool for managing tasks at the gRPC level. (#5576)
## Description of changes Expose the CRUD API for tasks over CLI. ## Test plan We use them and work with them as we need them. ## Migration plan N/A ## Observability plan N/A ## Documentation Changes N/A
1 parent ae2f4ca commit ff77a60

File tree

3 files changed

+233
-0
lines changed

3 files changed

+233
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/sysdb/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ chrono = { workspace = true }
2525
prost = { workspace = true }
2626
prost-types = { workspace = true }
2727
derivative = "2.2.0"
28+
clap = { workspace = true }
2829

2930
chroma-config = { workspace = true }
3031
chroma-error = { workspace = true, features = ["tonic", "sqlx"] }
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
use chroma_types::chroma_proto;
2+
use clap::{Parser, Subcommand};
3+
use prost_types::value::Kind;
4+
use tonic::transport::Channel;
5+
6+
#[derive(Parser)]
7+
#[command(name = "chroma-sysdb")]
8+
#[command(about = "CLI client for Chroma coordinator task management", long_about = None)]
9+
struct Cli {
10+
#[arg(
11+
long,
12+
default_value = "http://localhost:50051",
13+
help = "Address of the Chroma coordinator service"
14+
)]
15+
addr: String,
16+
17+
#[command(subcommand)]
18+
command: Command,
19+
}
20+
21+
#[derive(Subcommand)]
22+
enum Command {
23+
#[command(about = "Create a new task")]
24+
CreateTask {
25+
#[arg(long, help = "Name of the task")]
26+
name: String,
27+
#[arg(long, help = "Name of the operator to apply")]
28+
operator_name: String,
29+
#[arg(long, help = "ID of the input collection")]
30+
input_collection_id: String,
31+
#[arg(long, help = "Name for the output collection")]
32+
output_collection_name: String,
33+
#[arg(long, help = "JSON object containing operator parameters")]
34+
params: String,
35+
#[arg(long, help = "Tenant ID")]
36+
tenant_id: String,
37+
#[arg(long, help = "Database name")]
38+
database: String,
39+
#[arg(
40+
long,
41+
default_value = "100",
42+
help = "Minimum number of records required before task execution"
43+
)]
44+
min_records_for_task: u64,
45+
},
46+
#[command(about = "Get task by name")]
47+
GetTask {
48+
#[arg(long, help = "ID of the input collection")]
49+
input_collection_id: String,
50+
#[arg(long, help = "Name of the task to retrieve")]
51+
task_name: String,
52+
},
53+
#[command(about = "Delete a task")]
54+
DeleteTask {
55+
#[arg(long, help = "ID of the input collection")]
56+
input_collection_id: String,
57+
#[arg(long, help = "Name of the task to delete")]
58+
task_name: String,
59+
#[arg(long, help = "Whether to delete the output collection")]
60+
delete_output: bool,
61+
},
62+
#[command(about = "Mark a task run as ready to advance")]
63+
AdvanceTask {
64+
#[arg(long, help = "ID of the collection")]
65+
collection_id: String,
66+
#[arg(long, help = "ID of the task")]
67+
task_id: String,
68+
#[arg(long, help = "Nonce identifying the specific task run")]
69+
task_run_nonce: String,
70+
},
71+
#[command(about = "Get all operators")]
72+
GetOperators,
73+
#[command(about = "Peek schedule by collection IDs")]
74+
PeekSchedule {
75+
#[arg(
76+
long,
77+
value_delimiter = ',',
78+
help = "Comma-separated list of collection IDs"
79+
)]
80+
collection_ids: Vec<String>,
81+
},
82+
}
83+
84+
fn json_to_prost_value(json: serde_json::Value) -> prost_types::Value {
85+
let kind = match json {
86+
serde_json::Value::Null => Kind::NullValue(0),
87+
serde_json::Value::Bool(b) => Kind::BoolValue(b),
88+
serde_json::Value::Number(n) => {
89+
if let Some(f) = n.as_f64() {
90+
Kind::NumberValue(f)
91+
} else {
92+
Kind::NullValue(0)
93+
}
94+
}
95+
serde_json::Value::String(s) => Kind::StringValue(s),
96+
serde_json::Value::Array(arr) => Kind::ListValue(prost_types::ListValue {
97+
values: arr.into_iter().map(json_to_prost_value).collect(),
98+
}),
99+
serde_json::Value::Object(map) => Kind::StructValue(prost_types::Struct {
100+
fields: map
101+
.into_iter()
102+
.map(|(k, v)| (k, json_to_prost_value(v)))
103+
.collect(),
104+
}),
105+
};
106+
prost_types::Value { kind: Some(kind) }
107+
}
108+
109+
#[tokio::main]
110+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
111+
let cli = Cli::parse();
112+
113+
let channel = Channel::from_shared(cli.addr.clone())?.connect().await?;
114+
115+
let mut client = chroma_proto::sys_db_client::SysDbClient::new(channel);
116+
117+
match cli.command {
118+
Command::CreateTask {
119+
name,
120+
operator_name,
121+
input_collection_id,
122+
output_collection_name,
123+
params,
124+
tenant_id,
125+
database,
126+
min_records_for_task,
127+
} => {
128+
let params_json: serde_json::Value = serde_json::from_str(&params)?;
129+
let params_value = json_to_prost_value(params_json);
130+
let params_struct = match params_value.kind {
131+
Some(Kind::StructValue(s)) => Some(s),
132+
_ => {
133+
return Err("params must be a JSON object".into());
134+
}
135+
};
136+
137+
let request = chroma_proto::CreateTaskRequest {
138+
name,
139+
operator_name,
140+
input_collection_id,
141+
output_collection_name,
142+
params: params_struct,
143+
tenant_id,
144+
database,
145+
min_records_for_task,
146+
};
147+
148+
let response = client.create_task(request).await?;
149+
println!("Task created: {}", response.into_inner().task_id);
150+
}
151+
Command::GetTask {
152+
input_collection_id,
153+
task_name,
154+
} => {
155+
let request = chroma_proto::GetTaskByNameRequest {
156+
input_collection_id,
157+
task_name,
158+
};
159+
160+
let response = client.get_task_by_name(request).await?;
161+
let task = response.into_inner();
162+
163+
println!("Task ID: {:?}", task.task_id);
164+
println!("Name: {:?}", task.name);
165+
println!("Operator: {:?}", task.operator_name);
166+
println!("Input Collection: {:?}", task.input_collection_id);
167+
println!("Output Collection Name: {:?}", task.output_collection_name);
168+
println!("Output Collection ID: {:?}", task.output_collection_id);
169+
println!("Params: {:?}", task.params);
170+
println!("Completion Offset: {:?}", task.completion_offset);
171+
println!("Min Records: {:?}", task.min_records_for_task);
172+
}
173+
Command::DeleteTask {
174+
input_collection_id,
175+
task_name,
176+
delete_output,
177+
} => {
178+
let request = chroma_proto::DeleteTaskRequest {
179+
input_collection_id,
180+
task_name,
181+
delete_output,
182+
};
183+
184+
let response = client.delete_task(request).await?;
185+
println!("Task deleted: {}", response.into_inner().success);
186+
}
187+
Command::AdvanceTask {
188+
collection_id,
189+
task_id,
190+
task_run_nonce,
191+
} => {
192+
let request = chroma_proto::AdvanceTaskRequest {
193+
collection_id: Some(collection_id),
194+
task_id: Some(task_id),
195+
task_run_nonce: Some(task_run_nonce),
196+
};
197+
198+
client.advance_task(request).await?;
199+
println!("Task marked as done");
200+
}
201+
Command::GetOperators => {
202+
let request = chroma_proto::GetOperatorsRequest {};
203+
204+
let response = client.get_operators(request).await?;
205+
let operators = response.into_inner().operators;
206+
207+
for op in operators {
208+
println!(" {} - {}", op.id, op.name);
209+
}
210+
}
211+
Command::PeekSchedule { collection_ids } => {
212+
let request = chroma_proto::PeekScheduleByCollectionIdRequest {
213+
collection_id: collection_ids,
214+
};
215+
216+
let response = client.peek_schedule_by_collection_id(request).await?;
217+
let entries = response.into_inner().schedule;
218+
219+
println!("Schedule:");
220+
for entry in entries {
221+
println!(" Collection: {:?}", entry.collection_id);
222+
println!(" Task ID: {:?}", entry.task_id);
223+
println!(" Nonce: {:?}", entry.task_run_nonce);
224+
println!(" When: {:?}", entry.when_to_run);
225+
println!();
226+
}
227+
}
228+
}
229+
230+
Ok(())
231+
}

0 commit comments

Comments
 (0)