Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 207 additions & 14 deletions crates/chat-cli/src/api_client/model.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
use std::collections::HashMap;

use aws_smithy_types::{
Blob,
Document,
Document as AwsDocument,
};
use serde::de::{
self,
MapAccess,
SeqAccess,
Visitor,
};
use serde::{
Deserialize,
Deserializer,
Serialize,
Serializer,
};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -125,8 +135,189 @@ impl TryFrom<ChatMessage> for amzn_qdeveloper_streaming_client::types::ChatMessa
}
}

/// Information about a tool that can be used.
/// Wrapper around [aws_smithy_types::Document].
///
/// Used primarily so we can implement [Serialize] and [Deserialize] for
/// [aws_smith_types::Document].
#[derive(Debug, Clone)]
pub struct FigDocument(AwsDocument);

impl From<AwsDocument> for FigDocument {
fn from(value: AwsDocument) -> Self {
Self(value)
}
}

impl From<FigDocument> for AwsDocument {
fn from(value: FigDocument) -> Self {
value.0
}
}

/// Internal type used only during serialization for `FigDocument` to avoid unnecessary cloning.
#[derive(Debug, Clone)]
struct FigDocumentRef<'a>(&'a AwsDocument);

impl Serialize for FigDocumentRef<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use aws_smithy_types::Number;
match self.0 {
AwsDocument::Null => serializer.serialize_unit(),
AwsDocument::Bool(b) => serializer.serialize_bool(*b),
AwsDocument::Number(n) => match n {
Number::PosInt(u) => serializer.serialize_u64(*u),
Number::NegInt(i) => serializer.serialize_i64(*i),
Number::Float(f) => serializer.serialize_f64(*f),
},
AwsDocument::String(s) => serializer.serialize_str(s),
AwsDocument::Array(arr) => {
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(arr.len()))?;
for value in arr {
seq.serialize_element(&Self(value))?;
}
seq.end()
},
AwsDocument::Object(m) => {
use serde::ser::SerializeMap;
let mut map = serializer.serialize_map(Some(m.len()))?;
for (k, v) in m {
map.serialize_entry(k, &Self(v))?;
}
map.end()
},
}
}
}

impl Serialize for FigDocument {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
FigDocumentRef(&self.0).serialize(serializer)
}
}

impl<'de> Deserialize<'de> for FigDocument {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use aws_smithy_types::Number;

struct FigDocumentVisitor;

impl<'de> Visitor<'de> for FigDocumentVisitor {
type Value = FigDocument;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}

fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Bool(value)))
}

fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Number(if value < 0 {
Number::NegInt(value)
} else {
Number::PosInt(value as u64)
})))
}

fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Number(Number::PosInt(value))))
}

fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Number(Number::Float(value))))
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::String(value.to_owned())))
}

fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::String(value)))
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Null))
}

fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
Deserialize::deserialize(deserializer)
}

fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FigDocument(AwsDocument::Null))
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut vec = Vec::new();

while let Some(elem) = seq.next_element::<FigDocument>()? {
vec.push(elem.0);
}

Ok(FigDocument(AwsDocument::Array(vec)))
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut map = HashMap::new();

while let Some((key, value)) = access.next_entry::<String, FigDocument>()? {
map.insert(key, value.0);
}

Ok(FigDocument(AwsDocument::Object(map)))
}
}

deserializer.deserialize_any(FigDocumentVisitor)
}
}

/// Information about a tool that can be used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Tool {
ToolSpecification(ToolSpecification),
}
Expand All @@ -148,7 +339,7 @@ impl From<Tool> for amzn_qdeveloper_streaming_client::types::Tool {
}

/// The specification for the tool.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpecification {
/// The name for the tool.
pub name: String,
Expand Down Expand Up @@ -181,41 +372,41 @@ impl From<ToolSpecification> for amzn_qdeveloper_streaming_client::types::ToolSp
}

/// The input schema for the tool in JSON format.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInputSchema {
pub json: Option<Document>,
pub json: Option<FigDocument>,
}

impl From<ToolInputSchema> for amzn_codewhisperer_streaming_client::types::ToolInputSchema {
fn from(value: ToolInputSchema) -> Self {
Self::builder().set_json(value.json).build()
Self::builder().set_json(value.json.map(Into::into)).build()
}
}

impl From<ToolInputSchema> for amzn_qdeveloper_streaming_client::types::ToolInputSchema {
fn from(value: ToolInputSchema) -> Self {
Self::builder().set_json(value.json).build()
Self::builder().set_json(value.json.map(Into::into)).build()
}
}

/// Contains information about a tool that the model is requesting be run. The model uses the result
/// from the tool to generate a response.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUse {
/// The ID for the tool request.
pub tool_use_id: String,
/// The name for the tool.
pub name: String,
/// The input to pass to the tool.
pub input: Document,
pub input: FigDocument,
}

impl From<ToolUse> for amzn_codewhisperer_streaming_client::types::ToolUse {
fn from(value: ToolUse) -> Self {
Self::builder()
.tool_use_id(value.tool_use_id)
.name(value.name)
.input(value.input)
.input(value.input.into())
.build()
.expect("building ToolUse should not fail")
}
Expand All @@ -226,7 +417,7 @@ impl From<ToolUse> for amzn_qdeveloper_streaming_client::types::ToolUse {
Self::builder()
.tool_use_id(value.tool_use_id)
.name(value.name)
.input(value.input)
.input(value.input.into())
.build()
.expect("building ToolUse should not fail")
}
Expand Down Expand Up @@ -268,7 +459,7 @@ impl From<ToolResult> for amzn_qdeveloper_streaming_client::types::ToolResult {
#[derive(Debug, Clone)]
pub enum ToolResultContentBlock {
/// A tool result that is JSON format data.
Json(Document),
Json(AwsDocument),
/// A tool result that is text.
Text(String),
}
Expand Down Expand Up @@ -780,7 +971,7 @@ mod tests {
name: "test tool name".to_string(),
description: "test tool description".to_string(),
input_schema: ToolInputSchema {
json: Some(Document::Null),
json: Some(AwsDocument::Null.into()),
},
})]),
}),
Expand Down Expand Up @@ -814,7 +1005,9 @@ mod tests {
tool_uses: Some(vec![ToolUse {
tool_use_id: "tooluseid_test".to_string(),
name: "tool_name_test".to_string(),
input: Document::Object([("key1".to_string(), Document::Null)].into_iter().collect()),
input: FigDocument(AwsDocument::Object(
[("key1".to_string(), AwsDocument::Null)].into_iter().collect(),
)),
}]),
};
let codewhisper_input =
Expand Down
26 changes: 26 additions & 0 deletions crates/chat-cli/src/cli/chat/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ pub enum Command {
subcommand: Option<PromptsSubcommand>,
},
Usage,
Import {
path: String,
},
Export {
path: String,
force: bool,
},
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -811,6 +818,25 @@ impl Command {
}
},
"usage" => Self::Usage,
"import" => {
let Some(path) = parts.get(1) else {
return Err("path is required".to_string());
};
Self::Import {
path: (*path).to_string(),
}
},
"export" => {
let force = parts.contains(&"-f") || parts.contains(&"--force");
let Some(path) = parts.get(1) else {
return Err("path is required".to_string());
};
let mut path = (*path).to_string();
if !path.ends_with(".json") {
path.push_str(".json");
}
Self::Export { path, force }
},
unknown_command => {
let looks_like_path = {
let after_slash_command_str = parts[1..].join(" ");
Expand Down
8 changes: 7 additions & 1 deletion crates/chat-cli/src/cli/chat/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ pub struct ContextConfig {

#[allow(dead_code)]
/// Manager for context files and profiles.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextManager {
#[serde(skip)]
#[serde(default = "default_context")]
ctx: Arc<Context>,

max_context_files_size: usize,
Expand Down Expand Up @@ -798,6 +800,10 @@ fn validate_profile_name(name: &str) -> Result<()> {
Ok(())
}

fn default_context() -> Arc<Context> {
Context::new()
}

#[cfg(test)]
mod tests {
use std::io::Stdout;
Expand Down
Loading
Loading