Skip to content

Commit 0d61001

Browse files
committed
Use cbindgen to build the bindings
1 parent 8043d8e commit 0d61001

File tree

5 files changed

+144
-53
lines changed

5 files changed

+144
-53
lines changed

Makefile

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ RUST_LIB := infera/target/release/$(EXT_NAME).a
77
DUCKDB_SRCDIR := ./external/duckdb/
88
EXT_CONFIG := ${PROJ_DIR}extension_config.cmake
99
TESTS_DIR := tests
10+
EXAMPLES_DIR := docs/examples
1011
SHELL := /bin/bash
1112
PYTHON := python3
1213

@@ -76,8 +77,8 @@ rust-clean: ## Clean Rust build artifacts
7677
.PHONY: create-bindings
7778
create-bindings: ## Generate C bindings from Rust code
7879
@echo "Generating C bindings for Infera..."
79-
@cd infera && cbindgen --config cbindgen.toml --crate infera --output bindings/include/_rust.h
80-
@echo "C bindings generated at infera/bindings/include/_rust.h"
80+
@cd infera && cbindgen --config cbindgen.toml --crate infera --output bindings/include/rust.h
81+
@echo "C bindings generated at infera/bindings/include/rust.h"
8182

8283
# ==============================================================================
8384
# Targets for Building the Extension
@@ -138,11 +139,11 @@ clean-all: clean rust-clean ## Clean everything
138139
check: rust-lint rust-test ## Run all checks
139140
@echo "All checks passed!"
140141

141-
.PHONY: test-infera
142-
test-infera: ## Run SQL tests for the Infera extension
143-
@echo "Running every SQL file in the `tests/sql` directory..."
144-
@for sql_file in $(TESTS_DIR)/sql/*.sql; do \
145-
echo "Running test: $$sql_file"; \
142+
.PHONY: examples
143+
examples: ## Run SQL examples for Infera extension
144+
@echo "Running the examples in the ${EXAMPLES_DIR} directory..."
145+
@for sql_file in $(EXAMPLES_DIR)/*.sql; do \
146+
echo "Running example: $$sql_file"; \
146147
./build/release/duckdb < $$sql_file; \
147148
echo "============================================================================"; \
148149
done

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ Infera is DuckDB extension that lets you use machine learning (ML) models direct
2323
on data stored in DuckDB tables.
2424
It is developed in Rust and uses [Tract](https://github.com/snipsco/tract) as the backend inference engine.
2525
Infera supports loading and running models in [ONNX](https://onnx.ai/) format.
26-
Check out the [ONNX Model Zoo](https://huggingface.co/onnxmodelzoo) repositors on Hugging Face for a large collection of
27-
ready-to-use models (more than 1700) that can be used with Infera.
26+
Check out the [ONNX Model Zoo](https://huggingface.co/onnxmodelzoo) repositors on Hugging Face for a very large collection of
27+
ready-to-use models that can be used with Infera.
2828

2929
### Motivation
3030

infera/bindings/include/rust.h

Lines changed: 112 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,130 @@
1+
/* Generated with cbindgen */
2+
/* DO NOT EDIT */
3+
4+
15
#ifndef INFERA_H
26
#define INFERA_H
37

48
#pragma once
59

6-
#include <stddef.h>
10+
/* Generated with cbindgen:0.26.0 */
11+
12+
#include <stdarg.h>
13+
#include <stdbool.h>
714
#include <stdint.h>
15+
#include <stdlib.h>
816

917
#ifdef __cplusplus
10-
extern "C" {
11-
#endif
18+
namespace infera {
19+
#endif // __cplusplus
1220

13-
typedef struct {
21+
typedef struct InferaInferenceResult {
1422
float *data;
15-
size_t len;
16-
size_t rows;
17-
size_t cols;
23+
uintptr_t len;
24+
uintptr_t rows;
25+
uintptr_t cols;
1826
int32_t status;
1927
} InferaInferenceResult;
2028

21-
void infera_free(char *ptr);
22-
void infera_free_result(InferaInferenceResult result);
23-
const char *infera_last_error(void);
24-
char *infera_get_loaded_models(void);
25-
char *infera_get_model_info(const char *model_name);
26-
char *infera_get_version(void);
27-
char *infera_set_autoload_dir(const char *path);
28-
int32_t infera_load_model(const char *name, const char *path);
29-
int32_t infera_unload_model(const char *name);
30-
InferaInferenceResult infera_predict(const char *model_name, const float *data,
31-
size_t rows, size_t cols);
32-
InferaInferenceResult infera_predict_from_blob(const char *model_name,
33-
const uint8_t *blob_data,
34-
size_t blob_len);
29+
#ifdef __cplusplus
30+
extern "C" {
31+
#endif // __cplusplus
32+
33+
/**
34+
* Loads a model from a file or URL.
35+
*
36+
* # Safety
37+
* The `name` and `path` pointers must be valid, null-terminated C strings.
38+
*/
39+
int32_t infera_load_model(const char *name, const char *path);
40+
41+
/**
42+
* Unloads a model.
43+
*
44+
* # Safety
45+
* The `name` pointer must be a valid, null-terminated C string.
46+
*/
47+
int32_t infera_unload_model(const char *name);
48+
49+
/**
50+
* Runs inference on a model with the given input data.
51+
*
52+
* # Safety
53+
* The `model_name` and `data` pointers must be valid. `model_name` must be a
54+
* null-terminated C string. `data` must point to a contiguous block of memory
55+
* of size `rows * cols * size_of<f32>()`.
56+
*/
57+
58+
struct InferaInferenceResult infera_predict(const char *model_name,
59+
const float *data,
60+
uintptr_t rows,
61+
uintptr_t cols);
62+
63+
/**
64+
* Runs inference on a model with input data from a BLOB.
65+
*
66+
* # Safety
67+
* The `model_name` and `blob_data` pointers must be valid. `model_name` must be a
68+
* null-terminated C string. `blob_data` must point to a contiguous block of
69+
* memory of size `blob_len`.
70+
*/
71+
72+
struct InferaInferenceResult infera_predict_from_blob(const char *model_name,
73+
const uint8_t *blob_data,
74+
uintptr_t blob_len);
75+
76+
/**
77+
* Gets information about a loaded model.
78+
*
79+
* # Safety
80+
* The `model_name` pointer must be a valid, null-terminated C string.
81+
*/
82+
char *infera_get_model_info(const char *model_name);
83+
84+
char *infera_get_loaded_models(void);
85+
86+
char *infera_get_version(void);
87+
88+
/**
89+
* Sets a directory to automatically load models from.
90+
*
91+
* # Safety
92+
* The `path` pointer must be a valid, null-terminated C string.
93+
*/
94+
char *infera_set_autoload_dir(const char *path);
95+
96+
const char *infera_last_error(void);
97+
98+
/**
99+
* Frees a C string that was allocated by Rust.
100+
*
101+
* # Safety
102+
*
103+
* The `ptr` must be a pointer to a C string that was allocated by Rust's
104+
* `CString::into_raw`. Calling this function with a pointer that was not
105+
* allocated by `CString::into_raw` will result in undefined behavior.
106+
*/
107+
void infera_free(char *ptr);
108+
109+
/**
110+
* Frees the memory allocated for an `InferaInferenceResult`.
111+
*
112+
* # Safety
113+
*
114+
* The `res.data` pointer must have been allocated by Rust's `Vec` and the `res.len`
115+
* must be the correct length of the allocated memory. Calling this function
116+
* with a result that was not created by this library can lead to undefined behavior.
117+
*/
118+
void infera_free_result(struct InferaInferenceResult res);
119+
120+
#ifdef __cplusplus
121+
} // extern "C"
122+
#endif // __cplusplus
35123

36124
#ifdef __cplusplus
37-
}
125+
} // namespace infera
38126
#endif // __cplusplus
39127

40128
#endif /* INFERA_H */
129+
130+
/* End of generated bindings */

infera/bindings/infera_extension.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
namespace duckdb {
2525

2626
static std::string GetInferaError() {
27-
const char *err = infera_last_error();
27+
const char *err = infera::infera_last_error();
2828
return err ? std::string(err) : std::string("unknown error");
2929
}
3030

@@ -38,19 +38,19 @@ static void SetAutoloadDir(DataChunk &args, ExpressionState &state, Vector &resu
3838
throw InvalidInputException("Path cannot be NULL");
3939
}
4040
std::string path_str = path_val.ToString();
41-
char *result_json_c = infera_set_autoload_dir(path_str.c_str());
41+
char *result_json_c = infera::infera_set_autoload_dir(path_str.c_str());
4242
result.SetVectorType(VectorType::CONSTANT_VECTOR);
4343
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, result_json_c);
4444
ConstantVector::SetNull(result, false);
45-
infera_free(result_json_c);
45+
infera::infera_free(result_json_c);
4646
}
4747

4848
static void GetVersion(DataChunk &args, ExpressionState &state, Vector &result) {
49-
char *info_json_c = infera_get_version();
49+
char *info_json_c = infera::infera_get_version();
5050
result.SetVectorType(VectorType::CONSTANT_VECTOR);
5151
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, info_json_c);
5252
ConstantVector::SetNull(result, false);
53-
infera_free(info_json_c);
53+
infera::infera_free(info_json_c);
5454
}
5555

5656
static void LoadModel(DataChunk &args, ExpressionState &state, Vector &result) {
@@ -68,7 +68,7 @@ static void LoadModel(DataChunk &args, ExpressionState &state, Vector &result) {
6868
if (model_name_str.empty()) {
6969
throw InvalidInputException("Model name cannot be empty");
7070
}
71-
int rc = infera_load_model(model_name_str.c_str(), path_str.c_str());
71+
int rc = infera::infera_load_model(model_name_str.c_str(), path_str.c_str());
7272
bool success = rc == 0;
7373
if (!success) {
7474
throw InvalidInputException("Failed to load model '" + model_name_str + "': " + GetInferaError());
@@ -88,7 +88,7 @@ static void UnloadModel(DataChunk &args, ExpressionState &state, Vector &result)
8888
throw InvalidInputException("Model name cannot be NULL");
8989
}
9090
std::string model_name_str = model_name.ToString();
91-
int rc = infera_unload_model(model_name_str.c_str());
91+
int rc = infera::infera_unload_model(model_name_str.c_str());
9292
bool success = (rc == 0);
9393
if (!success) {
9494
throw InvalidInputException("Failed to unload model '" + model_name_str + "': " + GetInferaError());
@@ -143,21 +143,21 @@ static void Predict(DataChunk &args, ExpressionState &state, Vector &result) {
143143
std::vector<float> features;
144144
ExtractFeatures(args, features);
145145

146-
InferaInferenceResult res = infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
146+
infera::InferaInferenceResult res = infera::infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
147147
if (res.status != 0) {
148148
throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());
149149
}
150150
if (res.rows != batch_size || res.cols != 1) {
151151
std::string err_msg = StringUtil::Format("Model output shape mismatch. Expected (%d, 1), but got (%d, %d).", batch_size, res.rows, res.cols);
152-
infera_free_result(res);
152+
infera::infera_free_result(res);
153153
throw InvalidInputException(err_msg);
154154
}
155155
result.SetVectorType(VectorType::FLAT_VECTOR);
156156
auto result_data = FlatVector::GetData<float>(result);
157157
for (idx_t i = 0; i < batch_size; i++) {
158158
result_data[i] = res.data[i];
159159
}
160-
infera_free_result(res);
160+
infera::infera_free_result(res);
161161
}
162162

163163
static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &result) {
@@ -177,9 +177,9 @@ static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &res
177177
string_t blob_str_t = blob_val.GetValueUnsafe<string_t>();
178178
auto blob_ptr = reinterpret_cast<const uint8_t *>(blob_str_t.GetDataUnsafe());
179179
auto blob_len = blob_str_t.GetSize();
180-
InferaInferenceResult res = infera_predict_from_blob(model_name_str.c_str(), blob_ptr, blob_len);
180+
infera::InferaInferenceResult res = infera::infera_predict_from_blob(model_name_str.c_str(), blob_ptr, blob_len);
181181
if (res.status != 0) {
182-
infera_free_result(res);
182+
infera::infera_free_result(res);
183183
throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());
184184
}
185185
std::vector<Value> elems;
@@ -188,17 +188,17 @@ static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &res
188188
elems.emplace_back(Value::FLOAT(res.data[j]));
189189
}
190190
result.SetValue(i, Value::LIST(std::move(elems)));
191-
infera_free_result(res);
191+
infera::infera_free_result(res);
192192
}
193193
result.Verify(args.size());
194194
}
195195

196196
static void GetLoadedModels(DataChunk &args, ExpressionState &state, Vector &result) {
197-
char *models_json = infera_get_loaded_models();
197+
char *models_json = infera::infera_get_loaded_models();
198198
result.SetVectorType(VectorType::CONSTANT_VECTOR);
199199
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, models_json);
200200
ConstantVector::SetNull(result, false);
201-
infera_free(models_json);
201+
infera::infera_free(models_json);
202202
}
203203

204204
static void PredictMulti(DataChunk &args, ExpressionState &state, Vector &result) {
@@ -211,14 +211,14 @@ static void PredictMulti(DataChunk &args, ExpressionState &state, Vector &result
211211
std::vector<float> features;
212212
ExtractFeatures(args, features);
213213

214-
InferaInferenceResult res = infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
214+
infera::InferaInferenceResult res = infera::infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
215215
if (res.status != 0) {
216-
infera_free_result(res);
216+
infera::infera_free_result(res);
217217
throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());
218218
}
219219
if (res.rows != batch_size) {
220220
std::string err_msg = StringUtil::Format("Model output row count mismatch. Expected %d, but got %d.", batch_size, res.rows);
221-
infera_free_result(res);
221+
infera::infera_free_result(res);
222222
throw InvalidInputException(err_msg);
223223
}
224224
result.SetVectorType(VectorType::FLAT_VECTOR);
@@ -236,7 +236,7 @@ static void PredictMulti(DataChunk &args, ExpressionState &state, Vector &result
236236
oss << "]";
237237
result_data[row_idx] = StringVector::AddString(result, oss.str());
238238
}
239-
infera_free_result(res);
239+
infera::infera_free_result(res);
240240
}
241241

242242
static void GetModelInfo(DataChunk &args, ExpressionState &state, Vector &result) {
@@ -249,12 +249,12 @@ static void GetModelInfo(DataChunk &args, ExpressionState &state, Vector &result
249249
throw InvalidInputException("Model name cannot be NULL");
250250
}
251251
std::string model_name_str = model_name.ToString();
252-
char *json_meta = infera_get_model_info(model_name_str.c_str());
252+
char *json_meta = infera::infera_get_model_info(model_name_str.c_str());
253253

254254
result.SetVectorType(VectorType::CONSTANT_VECTOR);
255255
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, json_meta);
256256
ConstantVector::SetNull(result, false);
257-
infera_free(json_meta);
257+
infera::infera_free(json_meta);
258258
}
259259

260260
static void LoadInternal(ExtensionLoader &loader) {

infera/cbindgen.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace = "infera"
77
cpp_compat = true
88

99
# Output settings
10-
tab_width = 4
10+
tab_width = 2
1111
line_length = 100
1212
braces = "SameLine"
1313
include_guard = "INFERA_H"

0 commit comments

Comments
 (0)