Skip to content

Commit 7a31982

Browse files
committed
fix: update to new extension interface
1 parent 536f4e4 commit 7a31982

File tree

5 files changed

+123
-54
lines changed

5 files changed

+123
-54
lines changed

duckdb

Submodule duckdb updated 6772 files

src/include/open_prompt_extension.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#pragma once
22

33
#include "duckdb.hpp"
4+
#include "duckdb/main/extension/extension_loader.hpp"
45

56
namespace duckdb {
67

7-
using HeaderMap = case_insensitive_map_t<string>;
8-
98
class OpenPromptExtension : public Extension {
109
public:
11-
void Load(DuckDB &db) override;
10+
void Load(ExtensionLoader &loader) override;
1211
std::string Name() override;
1312
std::string Version() const override;
1413

src/include/open_prompt_secret.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#pragma once
22

33
#include "duckdb/main/secret/secret.hpp"
4-
#include "duckdb/main/extension_util.hpp"
4+
#include "duckdb/main/extension/extension_loader.hpp"
55

66
namespace duckdb {
77

88
struct CreateOpenPromptSecretFunctions {
99
public:
10-
static void Register(DatabaseInstance &instance);
10+
static void Register(ExtensionLoader &loader);
1111
};
1212

1313
} // namespace duckdb

src/open_prompt_extension.cpp

Lines changed: 115 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "open_prompt_extension.hpp"
33
#include "duckdb.hpp"
44
#include "duckdb/function/scalar_function.hpp"
5-
#include "duckdb/main/extension_util.hpp"
5+
#include "duckdb/main/extension/extension_loader.hpp"
66
#include "duckdb/common/atomic.hpp"
77
#include "duckdb/common/exception/http_exception.hpp"
88
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
@@ -45,7 +45,7 @@ namespace duckdb {
4545
res->json_system_prompt_idx = json_system_prompt_idx;
4646
return unique_ptr<FunctionData>(std::move(res));
4747
};
48-
bool Equals(const FunctionData &other) const {
48+
bool Equals(const FunctionData &other) const override {
4949
return model_idx == other.Cast<OpenPromptData>().model_idx &&
5050
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
5151
json_system_prompt_idx==other.Cast<OpenPromptData>().json_system_prompt_idx;
@@ -401,42 +401,118 @@ namespace duckdb {
401401
});
402402
}
403403

404-
// Complete LoadInternal function
405-
static void LoadInternal(DatabaseInstance &instance) {
406-
ScalarFunctionSet open_prompt("open_prompt");
407-
408-
open_prompt.AddFunction(ScalarFunction(
409-
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
410-
OpenPromptBind));
411-
open_prompt.AddFunction(ScalarFunction(
412-
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
413-
OpenPromptBind));
414-
open_prompt.AddFunction(ScalarFunction(
415-
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
416-
LogicalType::VARCHAR, OpenPromptRequestFunction,
417-
OpenPromptBind));
418-
open_prompt.AddFunction(ScalarFunction(
419-
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
420-
LogicalType::VARCHAR, OpenPromptRequestFunction,
421-
OpenPromptBind));
422-
423-
// Register Secret functions
424-
CreateOpenPromptSecretFunctions::Register(instance);
425-
426-
ExtensionUtil::RegisterFunction(instance, open_prompt);
427-
428-
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
429-
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
430-
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
431-
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
432-
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
433-
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
434-
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
435-
"set_api_timeout", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiTimeout));
404+
static void LoadInternal(ExtensionLoader &loader) {
405+
// Create open_prompt function set
406+
ScalarFunctionSet open_prompt_set("open_prompt");
407+
408+
// Single argument: prompt only
409+
open_prompt_set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR,
410+
OpenPromptRequestFunction, OpenPromptBind));
411+
412+
// Two arguments: prompt + model
413+
open_prompt_set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR},
414+
LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind));
415+
416+
// Three arguments: prompt + model + json_schema
417+
open_prompt_set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
418+
LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind));
419+
420+
// Four arguments: prompt + model + json_schema + system_prompt
421+
open_prompt_set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
422+
LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind));
423+
424+
// Create function info with documentation
425+
CreateScalarFunctionInfo open_prompt_info(open_prompt_set);
426+
427+
// Document single argument variant
428+
FunctionDescription desc1;
429+
desc1.parameter_names = {"prompt"};
430+
desc1.parameter_types = {LogicalType::VARCHAR};
431+
desc1.description = "Send a prompt to an OpenAI-compatible LLM API and return the response";
432+
desc1.examples = {"open_prompt('What is DuckDB?')"};
433+
desc1.categories = {"ai"};
434+
open_prompt_info.descriptions.push_back(desc1);
435+
436+
// Document two argument variant
437+
FunctionDescription desc2;
438+
desc2.parameter_names = {"prompt", "model"};
439+
desc2.parameter_types = {LogicalType::VARCHAR, LogicalType::VARCHAR};
440+
desc2.description = "Send a prompt to an LLM API using a specific model";
441+
desc2.examples = {"open_prompt('Explain SQL', 'gpt-4')"};
442+
desc2.categories = {"ai"};
443+
open_prompt_info.descriptions.push_back(desc2);
444+
445+
// Document three argument variant
446+
FunctionDescription desc3;
447+
desc3.parameter_names = {"prompt", "model", "json_schema"};
448+
desc3.parameter_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR};
449+
desc3.description = "Send a prompt to an LLM API with structured JSON output";
450+
desc3.examples = {"open_prompt('Extract name', 'gpt-4', '{\"type\":\"object\"}')"};
451+
desc3.categories = {"ai"};
452+
open_prompt_info.descriptions.push_back(desc3);
453+
454+
// Document four argument variant
455+
FunctionDescription desc4;
456+
desc4.parameter_names = {"prompt", "model", "json_schema", "system_prompt"};
457+
desc4.parameter_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR};
458+
desc4.description = "Send a prompt to an LLM API with a system prompt and structured output";
459+
desc4.examples = {"open_prompt('Hello', 'gpt-4', '{}', 'You are helpful')"};
460+
desc4.categories = {"ai"};
461+
open_prompt_info.descriptions.push_back(desc4);
462+
463+
loader.RegisterFunction(open_prompt_info);
464+
465+
// Register Secret functions
466+
CreateOpenPromptSecretFunctions::Register(loader);
467+
468+
// Configuration helper functions with documentation
469+
CreateScalarFunctionInfo set_token_info(ScalarFunction("set_api_token", {LogicalType::VARCHAR},
470+
LogicalType::VARCHAR, SetApiToken));
471+
FunctionDescription set_token_desc;
472+
set_token_desc.parameter_names = {"token"};
473+
set_token_desc.parameter_types = {LogicalType::VARCHAR};
474+
set_token_desc.description = "Set the API token for LLM authentication";
475+
set_token_desc.examples = {"set_api_token('sk-...')"};
476+
set_token_desc.categories = {"ai", "configuration"};
477+
set_token_info.descriptions.push_back(set_token_desc);
478+
loader.RegisterFunction(set_token_info);
479+
480+
CreateScalarFunctionInfo set_url_info(ScalarFunction("set_api_url", {LogicalType::VARCHAR},
481+
LogicalType::VARCHAR, SetApiUrl));
482+
FunctionDescription set_url_desc;
483+
set_url_desc.parameter_names = {"url"};
484+
set_url_desc.parameter_types = {LogicalType::VARCHAR};
485+
set_url_desc.description = "Set the API URL for LLM endpoint";
486+
set_url_desc.examples = {"set_api_url('https://api.openai.com/v1/chat/completions')"};
487+
set_url_desc.categories = {"ai", "configuration"};
488+
set_url_info.descriptions.push_back(set_url_desc);
489+
loader.RegisterFunction(set_url_info);
490+
491+
CreateScalarFunctionInfo set_model_info(ScalarFunction("set_model_name", {LogicalType::VARCHAR},
492+
LogicalType::VARCHAR, SetModelName));
493+
FunctionDescription set_model_desc;
494+
set_model_desc.parameter_names = {"model"};
495+
set_model_desc.parameter_types = {LogicalType::VARCHAR};
496+
set_model_desc.description = "Set the default model name for LLM requests";
497+
set_model_desc.examples = {"set_model_name('gpt-4')"};
498+
set_model_desc.categories = {"ai", "configuration"};
499+
set_model_info.descriptions.push_back(set_model_desc);
500+
loader.RegisterFunction(set_model_info);
501+
502+
CreateScalarFunctionInfo set_timeout_info(ScalarFunction("set_api_timeout", {LogicalType::VARCHAR},
503+
LogicalType::VARCHAR, SetApiTimeout));
504+
FunctionDescription set_timeout_desc;
505+
set_timeout_desc.parameter_names = {"timeout_seconds"};
506+
set_timeout_desc.parameter_types = {LogicalType::VARCHAR};
507+
set_timeout_desc.description = "Set the API timeout in seconds for LLM requests";
508+
set_timeout_desc.examples = {"set_api_timeout('30')"};
509+
set_timeout_desc.categories = {"ai", "configuration"};
510+
set_timeout_info.descriptions.push_back(set_timeout_desc);
511+
loader.RegisterFunction(set_timeout_info);
436512
}
437513

438-
void OpenPromptExtension::Load(DuckDB &db) {
439-
LoadInternal(*db.instance);
514+
void OpenPromptExtension::Load(ExtensionLoader &loader) {
515+
LoadInternal(loader);
440516
}
441517

442518
std::string OpenPromptExtension::Name() {
@@ -454,14 +530,9 @@ namespace duckdb {
454530
} // namespace duckdb
455531

456532
extern "C" {
457-
DUCKDB_EXTENSION_API void open_prompt_init(duckdb::DatabaseInstance &db) {
458-
duckdb::DuckDB db_wrapper(db);
459-
db_wrapper.LoadExtension<duckdb::OpenPromptExtension>();
460-
}
461-
462-
DUCKDB_EXTENSION_API const char *open_prompt_version() {
463-
return duckdb::DuckDB::LibraryVersion();
464-
}
533+
DUCKDB_CPP_EXTENSION_ENTRY(open_prompt, loader) {
534+
duckdb::LoadInternal(loader);
535+
}
465536
}
466537

467538
#ifndef DUCKDB_EXTENSION_MAIN

src/open_prompt_secret.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "open_prompt_secret.hpp"
22
#include "duckdb/common/exception.hpp"
33
#include "duckdb/main/secret/secret.hpp"
4-
#include "duckdb/main/extension_util.hpp"
54

65
namespace duckdb {
76

@@ -41,20 +40,20 @@ static unique_ptr<BaseSecret> CreateOpenPromptSecretFromConfig(ClientContext &co
4140
return std::move(result);
4241
}
4342

44-
void CreateOpenPromptSecretFunctions::Register(DatabaseInstance &instance) {
43+
void CreateOpenPromptSecretFunctions::Register(ExtensionLoader &loader) {
4544
string type = "open_prompt";
4645

4746
// Register the new type
4847
SecretType secret_type;
4948
secret_type.name = type;
5049
secret_type.deserializer = KeyValueSecret::Deserialize<KeyValueSecret>;
5150
secret_type.default_provider = "config";
52-
ExtensionUtil::RegisterSecretType(instance, secret_type);
51+
loader.RegisterSecretType(secret_type);
5352

5453
// Register the config secret provider
5554
CreateSecretFunction config_function = {type, "config", CreateOpenPromptSecretFromConfig};
5655
RegisterCommonSecretParameters(config_function);
57-
ExtensionUtil::RegisterFunction(instance, config_function);
56+
loader.RegisterFunction(config_function);
5857
}
5958

6059
} // namespace duckdb

0 commit comments

Comments
 (0)