Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 85dd6e0

Browse files
committed
Enable SQL queries on result sets.
Signed-off-by: ienkovich <[email protected]>
1 parent 5cd68ef commit 85dd6e0

File tree

7 files changed

+72
-18
lines changed

7 files changed

+72
-18
lines changed

omniscidb/Calcite/SchemaJson.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,29 +188,28 @@ std::string schema_to_json(SchemaProvider* schema_provider) {
188188
if (dbs.empty()) {
189189
return "{}";
190190
}
191-
// Current JSON format supports a single database only. So, we exclude
192-
// ResultSetRegistry from the schema for now (which makes it impossible
193-
// to run SQL queries on result sets).
194-
int db_id;
195-
if (dbs.size() == (size_t)1) {
196-
db_id = dbs.front();
197-
} else {
191+
// Current JSON format supports a single database only. To support result
192+
// sets in SQL queries, we add tables from the ResultSetRegistry using
193+
// negative table ids.
194+
auto tables = schema_provider->listTables(dbs[0]);
195+
if (dbs.size() != (size_t)1) {
198196
CHECK_EQ(dbs.size(), (size_t)2);
199197
CHECK(dbs[0] == hdk::ResultSetRegistry::DB_ID ||
200198
dbs[1] == hdk::ResultSetRegistry::DB_ID);
201-
db_id = dbs[0] == hdk::ResultSetRegistry::DB_ID ? dbs[1] : dbs[0];
199+
auto more_tables = schema_provider->listTables(dbs[1]);
200+
tables.insert(tables.end(), more_tables.begin(), more_tables.end());
202201
}
203202

204-
auto tables = schema_provider->listTables(db_id);
205-
206203
rapidjson::Document doc(rapidjson::kObjectType);
207204

208205
for (auto tinfo : tables) {
209206
rapidjson::Value table(rapidjson::kObjectType);
210207
table.AddMember("name",
211208
rapidjson::Value().SetString(rapidjson::StringRef(tinfo->name)),
212209
doc.GetAllocator());
213-
table.AddMember("id", rapidjson::Value().SetInt(tinfo->table_id), doc.GetAllocator());
210+
int table_id = tinfo->db_id == hdk::ResultSetRegistry::DB_ID ? -tinfo->table_id
211+
: tinfo->table_id;
212+
table.AddMember("id", rapidjson::Value().SetInt(table_id), doc.GetAllocator());
214213
table.AddMember(
215214
"columns", rapidjson::Value(rapidjson::kArrayType), doc.GetAllocator());
216215

omniscidb/QueryEngine/RelAlgDagBuilder.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "JsonAccessors.h"
2828
#include "RelAlgDagBuilder.h"
2929
#include "RelAlgOptimizer.h"
30+
#include "ResultSetRegistry/ResultSetRegistry.h"
3031
#include "ScalarExprVisitor.h"
3132
#include "Shared/sqldefs.h"
3233

@@ -2019,7 +2020,13 @@ TableInfoPtr getTableFromScanNode(int db_id,
20192020
const auto& table_json = field(scan_ra, "table");
20202021
CHECK(table_json.IsArray());
20212022
CHECK_EQ(unsigned(2), table_json.Size());
2022-
const auto info = schema_provider->getTableInfo(db_id, table_json[1].GetString());
2023+
auto info = schema_provider->getTableInfo(db_id, table_json[1].GetString());
2024+
// If the table wasn't found in the default database, then try searching in the
2025+
// result set registry.
2026+
if (!info) {
2027+
info = schema_provider->getTableInfo(hdk::ResultSetRegistry::DB_ID,
2028+
table_json[1].GetString());
2029+
}
20232030
CHECK(info);
20242031
return info;
20252032
}

omniscidb/Tests/ArrowSQLRunner/ArrowSQLRunner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class ArrowSQLRunnerImpl {
8888
std::string query_ra = getSqlQueryRelAlg(sql);
8989

9090
auto dag =
91-
std::make_unique<RelAlgDagBuilder>(query_ra, TEST_DB_ID, storage_, config_);
91+
std::make_unique<RelAlgDagBuilder>(query_ra, TEST_DB_ID, schema_mgr_, config_);
9292

9393
return std::make_unique<RelAlgExecutor>(executor_.get(), schema_mgr_, std::move(dag));
9494
}

omniscidb/Tests/QueryBuilderTest.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4748,6 +4748,29 @@ TEST_F(QueryBuilderTest, RowidOnResult) {
47484748
}
47494749
}
47504750

4751+
TEST_F(QueryBuilderTest, SqlOnResult) {
4752+
QueryBuilder builder(ctx(), schema_mgr_, configPtr());
4753+
4754+
auto res1 =
4755+
runSqlQuery("SELECT col_bi, col_i FROM test1;", ExecutorDeviceType::CPU, false);
4756+
compare_res_data(res1,
4757+
std::vector<int64_t>({1, 2, 3, 4, 5}),
4758+
std::vector<int32_t>({11, 22, 33, 44, 55}));
4759+
4760+
auto dag = builder.scan(res1.tableName()).proj({1, 0}).finalize();
4761+
auto res2 = runQuery(std::move(dag));
4762+
compare_res_data(res2,
4763+
std::vector<int32_t>({11, 22, 33, 44, 55}),
4764+
std::vector<int64_t>({1, 2, 3, 4, 5}));
4765+
4766+
auto res3 = runSqlQuery("SELECT col_bi + 1, col_i - 1 FROM " + res2.tableName() + ";",
4767+
ExecutorDeviceType::CPU,
4768+
false);
4769+
compare_res_data(res3,
4770+
std::vector<int64_t>({2, 3, 4, 5, 6}),
4771+
std::vector<int32_t>({10, 21, 32, 43, 54}));
4772+
}
4773+
47514774
class Taxi : public TestSuite {
47524775
protected:
47534776
static void SetUpTestSuite() {

python/pyhdk/_sql.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,13 @@ cdef class RelAlgExecutor:
107107
cdef unique_ptr[CQueryDag] c_dag
108108
cdef int db_id = 0
109109

110+
# Choose the default database ID. Ignore ResultSetRegistry.
110111
db_ids = schema_provider.listDatabases()
111-
assert len(db_ids) <= 1
112+
assert len(db_ids) <= 2
112113
if len(db_ids) == 1:
113114
db_id = db_ids[0]
115+
elif len(db_ids) == 2:
116+
db_id = db_ids[1] if db_ids[0] == ((100 << 24) + 1) else db_ids[0]
114117

115118
if ra_json is not None:
116119
c_dag.reset(new CRelAlgDagBuilder(ra_json, db_id, c_schema_provider, c_executor.getConfigPtr()))

python/pyhdk/hdk.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
DataMgr,
1212
SchemaMgr,
1313
)
14-
from pyhdk._sql import Calcite, RelAlgExecutor
14+
from pyhdk._sql import Calcite, RelAlgExecutor, ExecutionResult
1515
from pyhdk._execute import Executor, ResultSetRegistry
1616
from pyhdk._builder import QueryBuilder, QueryExpr, QueryNode
1717

@@ -2035,7 +2035,10 @@ def sql(self, sql_query, query_opts=None, **kwargs):
20352035

20362036
parts = []
20372037
for name, orig_table in kwargs.items():
2038-
if isinstance(orig_table, QueryNode) and orig_table.is_scan:
2038+
if (
2039+
isinstance(orig_table, (QueryNode, ExecutionResult))
2040+
and orig_table.is_scan
2041+
):
20392042
orig_table = orig_table.table_name
20402043
if not isinstance(orig_table, str):
20412044
raise TypeError(
@@ -2050,8 +2053,12 @@ def sql(self, sql_query, query_opts=None, **kwargs):
20502053

20512054
sql_query = "".join(parts) + sql_query
20522055
ra = self._calcite.process(sql_query)
2053-
ra_executor = RelAlgExecutor(self._executor, self._storage, self._data_mgr, ra)
2054-
return ra_executor.execute(**query_opts)
2056+
ra_executor = RelAlgExecutor(
2057+
self._executor, self._schema_mgr, self._data_mgr, ra
2058+
)
2059+
res = ra_executor.execute(**query_opts)
2060+
res.scan = self.scan(res.table_name)
2061+
return res
20552062

20562063
def scan(self, table_name):
20572064
"""

python/tests/test_pyhdk_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,21 @@ def test_alias(self):
925925
},
926926
)
927927

928+
def test_run_on_res(self):
929+
hdk = pyhdk.init()
930+
ht1 = hdk.import_pydict(
931+
{"a": [1, 2, 3, 4, 5], "b": [5, 4, 3, 2, 1], "x": [1.1, 2.2, 3.3, 4.4, 5.5]}
932+
)
933+
934+
res1 = hdk.sql("SELECT a, b FROM t1;", t1=ht1)
935+
self.check_res(res1, {"a": [1, 2, 3, 4, 5], "b": [5, 4, 3, 2, 1]})
936+
937+
res2 = hdk.sql("SELECT b + 1 as b, a - 1 as a FROM t1;", t1=res1)
938+
self.check_res(res2, {"b": [6, 5, 4, 3, 2], "a": [0, 1, 2, 3, 4]})
939+
940+
res3 = hdk.sql(f"SELECT b - 1 as b, a + 1 as a FROM {res1.table_name};")
941+
self.check_res(res3, {"b": [4, 3, 2, 1, 0], "a": [2, 3, 4, 5, 6]})
942+
928943

929944
class BaseTaxiTest:
930945
@staticmethod

0 commit comments

Comments
 (0)