Skip to content

Commit acd3836

Browse files
author
Dongri Jin
authored
Merge pull request #15 from dongri/add-function-in-role
Add function in role
2 parents 4521588 + 2c6db93 commit acd3836

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

examples/function_call_role.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
use openai_api_rs::v1::api::Client;
2+
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
3+
use serde::{Deserialize, Serialize};
4+
use std::collections::HashMap;
5+
use std::{env, vec};
6+
7+
async fn get_coin_price(coin: &str) -> f64 {
8+
let coin = coin.to_lowercase();
9+
match coin.as_str() {
10+
"btc" | "bitcoin" => 10000.0,
11+
"eth" | "ethereum" => 1000.0,
12+
_ => 0.0,
13+
}
14+
}
15+
16+
#[tokio::main]
17+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
18+
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
19+
20+
let mut properties = HashMap::new();
21+
properties.insert(
22+
"coin".to_string(),
23+
Box::new(chat_completion::JSONSchemaDefine {
24+
schema_type: Some(chat_completion::JSONSchemaType::String),
25+
description: Some("The cryptocurrency to get the price of".to_string()),
26+
enum_values: None,
27+
properties: None,
28+
required: None,
29+
items: None,
30+
}),
31+
);
32+
33+
let req = ChatCompletionRequest {
34+
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
35+
messages: vec![chat_completion::ChatCompletionMessage {
36+
role: chat_completion::MessageRole::user,
37+
content: Some(String::from("What is the price of Ethereum?")),
38+
name: None,
39+
function_call: None,
40+
}],
41+
functions: Some(vec![chat_completion::Function {
42+
name: String::from("get_coin_price"),
43+
description: Some(String::from("Get the price of a cryptocurrency")),
44+
parameters: Some(chat_completion::FunctionParameters {
45+
schema_type: chat_completion::JSONSchemaType::Object,
46+
properties: Some(properties),
47+
required: Some(vec![String::from("coin")]),
48+
}),
49+
}]),
50+
function_call: None,
51+
temperature: None,
52+
top_p: None,
53+
n: None,
54+
stream: None,
55+
stop: None,
56+
max_tokens: None,
57+
presence_penalty: None,
58+
frequency_penalty: None,
59+
logit_bias: None,
60+
user: None,
61+
};
62+
63+
let result = client.chat_completion(req).await?;
64+
65+
match result.choices[0].finish_reason {
66+
chat_completion::FinishReason::stop => {
67+
println!("Stop");
68+
println!("{:?}", result.choices[0].message.content);
69+
}
70+
chat_completion::FinishReason::length => {
71+
println!("Length");
72+
}
73+
chat_completion::FinishReason::function_call => {
74+
println!("FunctionCall");
75+
#[derive(Serialize, Deserialize)]
76+
struct Currency {
77+
coin: String,
78+
}
79+
let function_call = result.choices[0].message.function_call.as_ref().unwrap();
80+
let arguments = function_call.arguments.clone().unwrap();
81+
let c: Currency = serde_json::from_str(&arguments)?;
82+
let coin = c.coin;
83+
84+
let req = ChatCompletionRequest {
85+
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
86+
messages: vec![chat_completion::ChatCompletionMessage {
87+
role: chat_completion::MessageRole::user,
88+
content: Some(String::from("What is the price of Ethereum?")),
89+
name: None,
90+
function_call: None,
91+
}, chat_completion::ChatCompletionMessage {
92+
role: chat_completion::MessageRole::function,
93+
content: Some({
94+
let price = get_coin_price(&coin).await;
95+
format!("{{\"price\": {}}}", price)
96+
}),
97+
name: Some(String::from("get_coin_price")),
98+
function_call: None,
99+
}],
100+
functions: None,
101+
function_call: None,
102+
temperature: None,
103+
top_p: None,
104+
n: None,
105+
stream: None,
106+
stop: None,
107+
max_tokens: None,
108+
presence_penalty: None,
109+
frequency_penalty: None,
110+
logit_bias: None,
111+
user: None,
112+
};
113+
let result = client.chat_completion(req).await?;
114+
println!("{:?}", result.choices[0].message.content);
115+
}
116+
chat_completion::FinishReason::content_filter => {
117+
println!("ContentFilter");
118+
}
119+
chat_completion::FinishReason::null => {
120+
println!("Null");
121+
}
122+
}
123+
Ok(())
124+
}
125+
126+
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call_role

src/v1/chat_completion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub enum MessageRole {
4949
user,
5050
system,
5151
assistant,
52+
function,
5253
}
5354

5455
#[derive(Debug, Serialize, Deserialize)]

0 commit comments

Comments
 (0)