Skip to content

Commit 409350e

Browse files
authored
Merge pull request #100 from chdb-io/binUdf
Support UDF in Python
2 parents 4061f73 + afc3faa commit 409350e

File tree

9 files changed

+280
-42
lines changed

9 files changed

+280
-42
lines changed

README-zh.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ sess.query(
106106
print("Select from view:\n")
107107
print(sess.query("SELECT * FROM db_xxx.view_xxx", "Pretty"))
108108
```
109-
109+
110+
参见: [test_stateful.py](tests/test_stateful.py)
110111
</details>
111112

112113
<details>
@@ -126,6 +127,23 @@ conn1.close()
126127
```
127128
</details>
128129

130+
<details>
131+
<summary><h4>🗂️ Query with UDF(User Defined Functions)</h4></summary>
132+
133+
```python
134+
from chdb.udf import chdb_udf
135+
from chdb import query
136+
137+
@chdb_udf()
138+
def sum_udf(lhs, rhs):
139+
return int(lhs) + int(rhs)
140+
141+
print(query("select sum_udf(12,22)"))
142+
```
143+
144+
参见: [test_udf.py](tests/test_udf.py).
145+
</details>
146+
129147
更多示例,请参见 [examples](examples)[tests](tests)
130148

131149
## 演示和示例

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ sess.query(
111111
print("Select from view:\n")
112112
print(sess.query("SELECT * FROM db_xxx.view_xxx", "Pretty"))
113113
```
114-
114+
115+
see also: [test_stateful.py](tests/test_stateful.py).
115116
</details>
116117

117118
<details>
@@ -132,6 +133,23 @@ conn1.close()
132133
</details>
133134

134135

136+
<details>
137+
<summary><h4>🗂️ Query with UDF(User Defined Functions)</h4></summary>
138+
139+
```python
140+
from chdb.udf import chdb_udf
141+
from chdb import query
142+
143+
@chdb_udf()
144+
def sum_udf(lhs, rhs):
145+
return int(lhs) + int(rhs)
146+
147+
print(query("select sum_udf(12,22)"))
148+
```
149+
150+
see also: [test_udf.py](tests/test_udf.py).
151+
</details>
152+
135153
For more examples, see [examples](examples) and [tests](tests).
136154

137155
<br>

chdb/__init__.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import sys
22
import os
33

4+
5+
# If any UDF is defined, the path of the UDF will be set to this variable
6+
# and the path will be deleted when the process exits
7+
# UDF config path will be f"{g_udf_path}/udf_config.xml"
8+
# UDF script path will be f"{g_udf_path}/{func_name}.py"
9+
g_udf_path = ""
10+
411
chdb_version = (0, 6, 0)
512
if sys.version_info[:2] >= (3, 7):
613
# get the path of the current file
@@ -32,37 +39,30 @@ def to_arrowTable(res):
3239
import pyarrow as pa
3340
import pandas
3441
except ImportError as e:
35-
print(f'ImportError: {e}')
42+
print(f"ImportError: {e}")
3643
print('Please install pyarrow and pandas via "pip install pyarrow pandas"')
37-
raise ImportError('Failed to import pyarrow or pandas') from None
44+
raise ImportError("Failed to import pyarrow or pandas") from None
3845
if len(res) == 0:
3946
return pa.Table.from_batches([], schema=pa.schema([]))
4047
return pa.RecordBatchFileReader(res.bytes()).read_all()
4148

4249

4350
# return pandas dataframe
4451
def to_df(r):
45-
""""convert arrow table to Dataframe"""
52+
"""convert arrow table to Dataframe"""
4653
t = to_arrowTable(r)
4754
return t.to_pandas(use_threads=True)
4855

4956

5057
# wrap _chdb functions
51-
def query(sql, output_format="CSV"):
52-
lower_output_format = output_format.lower()
53-
if lower_output_format == "dataframe":
54-
return to_df(_chdb.query(sql, "Arrow"))
55-
elif lower_output_format == 'arrowtable':
56-
return to_arrowTable(_chdb.query(sql, "Arrow"))
57-
else:
58-
return _chdb.query(sql, output_format)
59-
60-
61-
def query_stateful(sql, output_format="CSV", path=None):
58+
def query(sql, output_format="CSV", path="", udf_path=""):
59+
global g_udf_path
60+
if udf_path != "":
61+
g_udf_path = udf_path
6262
lower_output_format = output_format.lower()
6363
if lower_output_format == "dataframe":
64-
return to_df(_chdb.query_stateful(sql, "Arrow", path))
65-
elif lower_output_format == 'arrowtable':
66-
return to_arrowTable(_chdb.query_stateful(sql, "Arrow", path))
64+
return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
65+
elif lower_output_format == "arrowtable":
66+
return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
6767
else:
68-
return _chdb.query_stateful(sql, output_format, path)
68+
return _chdb.query(sql, output_format, path=path, udf_path=g_udf_path)

chdb/session/state.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
import tempfile
22
import shutil
33

4-
from chdb import query_stateful
4+
from chdb import query, g_udf_path
55

66

7-
class Session():
7+
class Session:
88
"""
99
Session will keep the state of query. All DDL and DML state will be kept in a dir.
1010
Dir path could be passed in as an argument. If not, a temporary dir will be created.
1111
1212
If path is not specified, the temporary dir will be deleted when the Session object is deleted.
1313
Otherwise path will be kept.
14-
15-
Note: The default database is "_local" and the default engine is "Memory" which means all data
14+
15+
Note: The default database is "_local" and the default engine is "Memory" which means all data
1616
will be stored in memory. If you want to store data in disk, you should create another database.
1717
"""
1818

@@ -28,11 +28,20 @@ def __del__(self):
2828
if self._cleanup:
2929
self.cleanup()
3030

31+
def __enter__(self):
32+
return self
33+
34+
def __exit__(self, exc_type, exc_value, traceback):
35+
self.cleanup()
36+
3137
def cleanup(self):
32-
shutil.rmtree(self._path)
38+
try:
39+
shutil.rmtree(self._path)
40+
except:
41+
pass
3342

3443
def query(self, sql, fmt="CSV"):
3544
"""
3645
Execute a query.
3746
"""
38-
return query_stateful(sql, fmt, path=self._path)
47+
return query(sql, fmt, path=self._path)

chdb/udf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .udf import *

chdb/udf/udf.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import functools
2+
import inspect
3+
import os
4+
import sys
5+
import tempfile
6+
import atexit
7+
import shutil
8+
import textwrap
9+
from xml.etree import ElementTree as ET
10+
import chdb
11+
12+
13+
def generate_udf(func_name, args, return_type, udf_body):
14+
# generate python script
15+
with open(f"{chdb.g_udf_path}/{func_name}.py", "w") as f:
16+
f.write(f"#!{sys.executable}\n")
17+
f.write("import sys\n")
18+
f.write("\n")
19+
for line in udf_body.split("\n"):
20+
f.write(f"{line}\n")
21+
f.write("\n")
22+
f.write("if __name__ == '__main__':\n")
23+
f.write(" for line in sys.stdin:\n")
24+
f.write(" args = line.strip().split('\t')\n")
25+
for i, arg in enumerate(args):
26+
f.write(f" {arg} = args[{i}]\n")
27+
f.write(f" print({func_name}({', '.join(args)}))\n")
28+
f.write(" sys.stdout.flush()\n")
29+
os.chmod(f"{chdb.g_udf_path}/{func_name}.py", 0o755)
30+
# generate xml file
31+
xml_file = f"{chdb.g_udf_path}/udf_config.xml"
32+
root = ET.Element("functions")
33+
if os.path.exists(xml_file):
34+
tree = ET.parse(xml_file)
35+
root = tree.getroot()
36+
function = ET.SubElement(root, "function")
37+
ET.SubElement(function, "type").text = "executable"
38+
ET.SubElement(function, "name").text = func_name
39+
ET.SubElement(function, "return_type").text = return_type
40+
ET.SubElement(function, "format").text = "TabSeparated"
41+
ET.SubElement(function, "command").text = f"{func_name}.py"
42+
for arg in args:
43+
argument = ET.SubElement(function, "argument")
44+
# We use TabSeparated format, so assume all arguments are strings
45+
ET.SubElement(argument, "type").text = "String"
46+
ET.SubElement(argument, "name").text = arg
47+
tree = ET.ElementTree(root)
48+
tree.write(xml_file)
49+
50+
51+
def chdb_udf(return_type="String"):
52+
"""
53+
Decorator for chDB Python UDF(User Defined Function).
54+
1. The function should be stateless. So, only UDFs are supported, not UDAFs(User Defined Aggregation Function).
55+
2. Default return type is String. If you want to change the return type, you can pass in the return type as an argument.
56+
The return type should be one of the following: https://clickhouse.com/docs/en/sql-reference/data-types
57+
3. The function should take in arguments of type String. As the input is TabSeparated, all arguments are strings.
58+
4. The function will be called for each line of input. Something like this:
59+
```
60+
def sum_udf(lhs, rhs):
61+
return int(lhs) + int(rhs)
62+
63+
for line in sys.stdin:
64+
args = line.strip().split('\t')
65+
lhs = args[0]
66+
rhs = args[1]
67+
print(sum_udf(lhs, rhs))
68+
sys.stdout.flush()
69+
```
70+
5. The function should be pure python function. You SHOULD import all python modules used IN THE FUNCTION.
71+
```
72+
def func_use_json(arg):
73+
import json
74+
...
75+
```
76+
6. Python interpertor used is the same as the one used to run the script. Get from `sys.executable`
77+
"""
78+
79+
def decorator(func):
80+
func_name = func.__name__
81+
sig = inspect.signature(func)
82+
args = list(sig.parameters.keys())
83+
src = inspect.getsource(func)
84+
src = textwrap.dedent(src)
85+
udf_body = src.split("\n", 1)[1] # remove the first line "@chdb_udf()"
86+
# create tmp dir and make sure the dir is deleted when the process exits
87+
if chdb.g_udf_path == "":
88+
chdb.g_udf_path = tempfile.mkdtemp()
89+
90+
# clean up the tmp dir on exit
91+
@atexit.register
92+
def _cleanup():
93+
try:
94+
shutil.rmtree(chdb.g_udf_path)
95+
except:
96+
pass
97+
98+
generate_udf(func_name, args, return_type, udf_body)
99+
100+
@functools.wraps(func)
101+
def wrapper(*args, **kwargs):
102+
return func(*args, **kwargs)
103+
104+
return wrapper
105+
106+
return decorator

programs/local/LocalChdb.cpp

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
extern bool inside_main = true;
77

88

9-
local_result * queryToBuffer(const std::string & queryStr, const std::string & format = "CSV", const std::string & path = {})
9+
local_result * queryToBuffer(
10+
const std::string & queryStr,
11+
const std::string & output_format = "CSV",
12+
const std::string & path = {},
13+
const std::string & udfPath = {})
1014
{
11-
std::vector<std::string> argv = {"clickhouse", "--multiquery"};
15+
std::vector<std::string> argv = {"clickhouse", "--", "--multiquery"};
1216

13-
// if format is "Debug" or "debug", then we will add --verbose and --log-level=trace to argv
14-
if (format == "Debug" || format == "debug")
17+
// If format is "Debug" or "debug", then we will add `--verbose` and `--log-level=trace` to argv
18+
if (output_format == "Debug" || output_format == "debug")
1519
{
1620
argv.push_back("--verbose");
1721
argv.push_back("--log-level=trace");
@@ -21,10 +25,19 @@ local_result * queryToBuffer(const std::string & queryStr, const std::string & f
2125
else
2226
{
2327
// Add format string
24-
argv.push_back("--output-format=" + format);
28+
argv.push_back("--output-format=" + output_format);
2529
}
2630

27-
if (!path.empty())
31+
// If udfPath is not empty, then we will add `--user_scripts_path` and `--user_defined_executable_functions_config` to argv
32+
// the path should be a one time thing, so the caller should take care of the temporary files deletion
33+
if (!udfPath.empty())
34+
{
35+
argv.push_back("--user_scripts_path=" + udfPath);
36+
argv.push_back("--user_defined_executable_functions_config=" + udfPath + "/*.xml");
37+
}
38+
39+
// If path is not empty, then we will add `--path` to argv. This is used for chdb.Session to support stateful query
40+
if (!path.empty())
2841
{
2942
// Add path string
3043
argv.push_back("--path=" + path);
@@ -42,14 +55,13 @@ local_result * queryToBuffer(const std::string & queryStr, const std::string & f
4255

4356
// Pybind11 will take over the ownership of the `query_result` object
4457
// using smart ptr will cause early free of the object
45-
query_result * query(const std::string & queryStr, const std::string & format = "CSV")
58+
query_result * query(
59+
const std::string & queryStr,
60+
const std::string & output_format = "CSV",
61+
const std::string & path = {},
62+
const std::string & udfPath = {})
4663
{
47-
return new query_result(queryToBuffer(queryStr, format));
48-
}
49-
50-
query_result * query_stateful(const std::string & queryStr, const std::string & format = "CSV", const std::string & path = {})
51-
{
52-
return new query_result(queryToBuffer(queryStr, format, path));
64+
return new query_result(queryToBuffer(queryStr, output_format, path, udfPath));
5365
}
5466

5567
// The `query_result` and `memoryview_wrapper` will hold `local_result_wrapper` with shared_ptr
@@ -132,9 +144,15 @@ PYBIND11_MODULE(_chdb, m)
132144
.def("get_memview", &query_result::get_memview);
133145

134146

135-
m.def("query", &query, "Stateless query Clickhouse and return a query_result object");
136-
137-
m.def("query_stateful", &query_stateful, "Stateful query Clickhouse and return a query_result object");
147+
m.def(
148+
"query",
149+
&query,
150+
py::arg("queryStr"),
151+
py::arg("output_format") = "CSV",
152+
py::kw_only(),
153+
py::arg("path") = "",
154+
py::arg("udf_path") = "",
155+
"Query chDB and return a query_result object");
138156
}
139157

140158
#endif // PY_TEST_MAIN

tests/test_stateful.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ def test_tmp(self):
9393
ret = sess2.query("SELECT chdb_xxx()", "CSV")
9494
self.assertEqual(str(ret), "")
9595

96+
def test_context_mgr(self):
97+
with session.Session() as sess:
98+
sess.query("CREATE FUNCTION chdb_xxx AS () -> '0.12.0'", "CSV")
99+
ret = sess.query("SELECT chdb_xxx()", "CSV")
100+
self.assertEqual(str(ret), '"0.12.0"\n')
101+
102+
with session.Session() as sess:
103+
ret = sess.query("SELECT chdb_xxx()", "CSV")
104+
self.assertEqual(str(ret), "")
105+
96106
def test_zfree_thread_count(self):
97107
time.sleep(3)
98108
thread_count = current_process.num_threads()

0 commit comments

Comments
 (0)