Skip to content

Commit bc9ef05

Browse files
committed
Progress on chatbot kernel
1 parent 854585d commit bc9ef05

File tree

6 files changed

+184
-5
lines changed

6 files changed

+184
-5
lines changed

MODULE.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ bazel_dep(name = "rules_python", version = "0.37.2")
1414
bazel_dep(name = "platforms", version = "0.0.10")
1515
bazel_dep(name = "googletest", version = "1.15.2")
1616
bazel_dep(name = "apple_support", version = "1.17.1", repo_name = "build_bazel_apple_support")
17+
bazel_dep(name = "curl", version = "8.8.0")
18+
bazel_dep(name = "nlohmann_json", version = "3.11.3")
1719

1820
# Use archive_override to patch rules_foreign_cc to default to specific cmake version
1921
archive_override(

MODULE.bazel.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kernels/ai_server/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Kernels that query a remote AI model server (built from every .cpp/.h/.hpp
# file in this package).
cc_library(
    name = "llm_kernels",
    srcs = glob([
        "*.cpp",
    ]),
    hdrs = glob([
        "*.h",
        "*.hpp",
    ]),
    # `includes` entries are directories relative to this package, not labels.
    # The former "//framework/include" entry was therefore not a usable include
    # path and has been dropped.
    # NOTE(review): framework headers are presumably exported by //:corevx via
    # `deps` -- confirm.
    includes = [
        ".",
    ],
    deps = [
        "//:corevx",
        "@curl//:curl",
        "@nlohmann_json//:json",
    ],
    visibility = ["//visibility:public"],
)

kernels/ai_server/chatbot.hpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/**
2+
* @file chatbot.hpp
3+
* @brief Kernel for AI Model Server Chatbot
4+
* @version 0.1
5+
* @date 2025-04-04
6+
*
7+
* @copyright Copyright (c) 2025
8+
*
9+
*/
10+
#include <curl/curl.h>
11+
#include <nlohmann/json.hpp>
12+
#include <string>
13+
#include <vector>
14+
#include <VX/vx.h>
15+
16+
#define DEFAULT_MODEL "gpt-4o-mini"
17+
#define SERVER_URL "http://localhost:8000"
18+
#define API_KEY "hardcoded-api-key"
19+
20+
class RemoteModelClient
21+
{
22+
private:
23+
// Helper function for non-streaming response
24+
static size_t WriteCallback(void *contents, size_t size, size_t nmemb, void *userp)
25+
{
26+
size_t totalSize = size * nmemb;
27+
((std::string *)userp)->append((char *)contents, totalSize);
28+
return totalSize;
29+
}
30+
31+
public:
32+
// kernel function (non-streaming)
33+
vx_status AiServerQuery(const std::string &input_text, std::string &output_text, const std::string &api_path)
34+
{
35+
CURL *curl = curl_easy_init();
36+
if (!curl)
37+
return VX_FAILURE;
38+
39+
nlohmann::json request_json = {
40+
{"model", DEFAULT_MODEL},
41+
{"messages", {{{"role", "user"}, {"content", input_text}}}},
42+
{"max_tokens", 100},
43+
{"stream", false}};
44+
45+
std::string request_payload = request_json.dump();
46+
std::string response_string;
47+
std::string api_url = std::string(SERVER_URL) + api_path;
48+
49+
struct curl_slist *headers = nullptr;
50+
headers = curl_slist_append(headers, "Content-Type: application/json");
51+
headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(API_KEY)).c_str());
52+
53+
curl_easy_setopt(curl, CURLOPT_URL, api_url.c_str());
54+
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_payload.c_str());
55+
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
56+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
57+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string);
58+
59+
CURLcode res = curl_easy_perform(curl);
60+
curl_slist_free_all(headers);
61+
curl_easy_cleanup(curl);
62+
63+
if (res != CURLE_OK)
64+
return VX_FAILURE;
65+
66+
auto json_response = nlohmann::json::parse(response_string);
67+
output_text = json_response["choices"][0]["message"]["content"];
68+
69+
return VX_SUCCESS;
70+
}
71+
72+
// kernel function (streaming)
73+
vx_status AiServerQueryStream(const std::string &input_text, std::string &output_text, const std::string &api_path)
74+
{
75+
CURL *curl = curl_easy_init();
76+
if (!curl)
77+
return VX_FAILURE;
78+
79+
nlohmann::json request_json = {
80+
{"model", DEFAULT_MODEL},
81+
{"messages", {{{"role", "user"}, {"content", input_text}}}},
82+
{"max_tokens", 100},
83+
{"stream", true}};
84+
85+
std::string request_payload = request_json.dump();
86+
std::string response_chunk;
87+
std::string api_url = std::string(SERVER_URL) + api_path;
88+
89+
struct curl_slist *headers = nullptr;
90+
headers = curl_slist_append(headers, "Content-Type: application/json");
91+
headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(API_KEY)).c_str());
92+
93+
curl_easy_setopt(curl, CURLOPT_URL, api_url.c_str());
94+
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_payload.c_str());
95+
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
96+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
97+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_chunk);
98+
99+
CURLcode res = curl_easy_perform(curl);
100+
curl_slist_free_all(headers);
101+
curl_easy_cleanup(curl);
102+
103+
if (res != CURLE_OK)
104+
return VX_FAILURE;
105+
106+
// Just return raw streamed response (newline-delimited JSON chunks)
107+
output_text = response_chunk;
108+
return VX_SUCCESS;
109+
}
110+
};

targets/ai_server/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ cc_library(
88
includes = [
99
".",
1010
"//framework/include",
11-
# "//kernels/ai-server",
11+
"//kernels/ai-server",
1212
],
1313
deps = [
1414
"//:corevx",
15-
# "//kernels/ai-server:ai-server-kernels",
15+
"//kernels/ai_server:llm_kernels"
1616
],
1717
visibility = ["//visibility:public"]
1818
)

targets/ai_server/vx_chatbot.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,51 @@
99
*/
1010
#include <iostream>
1111
#include <string>
12+
#include <unordered_map>
1213

1314
#include <VX/vx.h>
1415
#include <VX/vx_compatibility.h>
1516
#include <VX/vx_helper.h>
1617
#include <VX/vx_lib_debug.h>
1718

19+
#include "chatbot.hpp"
1820
#include "vx_internal.h"
1921

22+
// Shared RemoteModelClient instance used by the kernel callback below to
// query the model server.
23+
static const std::shared_ptr<RemoteModelClient> kernel = std::make_shared<RemoteModelClient>();
24+
25+
// Maps a short API name to the URL path passed to the client as `api_path`.
// The kernel callback looks entries up with operator[], so the map itself is
// not declared const.
static std::unordered_map<std::string, const std::string> api_map = {
26+
{"chat", "/v1/chat/completions"},
27+
};
28+
2029
class VxRemoteModelClient
2130
{
31+
private:
32+
/**
 * @brief Replace the contents of a vx_array with the characters of a string.
 *
 * @param arr Destination array; it is truncated to zero items first.
 * @param in  Source text; in.size() items of sizeof(char) stride are appended.
 * @return The truncation status if truncation fails, otherwise the status of
 *         the append.
 */
static vx_status store_vx_string_to_array(vx_array arr, const vx_string &in)
{
    // Drop whatever the array currently holds before writing the new text.
    const vx_status truncated = vxTruncateArray(arr, 0);

    return (truncated == VX_SUCCESS)
               ? vxAddArrayItems(arr, in.size(), in.data(), sizeof(char))
               : truncated;
}
40+
41+
/**
 * @brief Copy every item of a vx_array into a string.
 *
 * @param arr Source array, read with a stride of sizeof(char).
 * @param out Resized to the array's item count and filled with its contents.
 *            NOTE(review): vx_string is presumably std::string -- confirm.
 * @return VX_FAILURE if the item count cannot be queried or is zero,
 *         otherwise the status of the copy.
 */
static vx_status load_vx_string_from_array(vx_array arr, vx_string &out)
{
    vx_size item_count = 0;
    const vx_status queried =
        vxQueryArray(arr, VX_ARRAY_ATTRIBUTE_NUMITEMS, &item_count, sizeof(item_count));

    // An array that cannot be queried, or that is empty, is reported as a failure.
    const bool usable = (queried == VX_SUCCESS) && (item_count != 0);
    if (!usable)
    {
        return VX_FAILURE;
    }

    // Size the string up front so the copy can write straight into its buffer.
    out.resize(item_count);
    return vxCopyArrayRange(arr, 0, item_count, sizeof(char), out.data(), VX_READ_ONLY, VX_MEMORY_TYPE_HOST);
}
52+
2253
public:
2354
static constexpr vx_param_description_t kernelParams[] = {
24-
{VX_INPUT, VX_TYPE_STRING, VX_PARAMETER_STATE_REQUIRED}, // Parameter 0: Input text
25-
{VX_OUTPUT, VX_TYPE_STRING, VX_PARAMETER_STATE_REQUIRED}, // Parameter 1: Output text
55+
{VX_INPUT, VX_TYPE_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 0: Input text
56+
{VX_OUTPUT, VX_TYPE_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 1: Output text
2657
};
2758

2859
static vx_status VX_CALLBACK init(vx_node node, const vx_reference parameters[], vx_uint32 num)
@@ -47,7 +78,17 @@ class VxRemoteModelClient
4778
(void)node;
4879
(void)parameters;
4980
(void)num;
50-
return VX_SUCCESS;
81+
vx_status status = VX_SUCCESS;
82+
vx_string input_text, output_text;
83+
84+
status = load_vx_string_from_array((vx_array)parameters[0], input_text);
85+
status |= kernel->AiServerQuery(
86+
input_text, // Input text
87+
output_text, // Output text
88+
api_map["chat"]); // API path
89+
status |= store_vx_string_to_array((vx_array)parameters[1], output_text);
90+
91+
return status;
5192
}
5293
};
5394

0 commit comments

Comments
 (0)