Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 1456374

Browse files
committed
Enable execution results usage in the query builder.
Signed-off-by: ienkovich <[email protected]>
1 parent b2b2c3a commit 1456374

19 files changed

+250
-43
lines changed

omniscidb/QueryEngine/Descriptors/RelAlgExecutionDescriptor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ class ExecutionResult {
5151

5252
bool empty() const { return !result_token_; }
5353

54+
// Returns the table name reported by the result token held in result_token_.
// Must only be called on a non-empty result: empty() is true when there is no
// result token, and the CHECK below enforces that precondition before the
// token is dereferenced.
const std::string& tableName() const {
55+
CHECK(!empty());
56+
return result_token_->tableName();
57+
}
58+
5459
const std::vector<TargetMetaInfo>& getTargetsMeta() const { return targets_meta_; }
5560

5661
const std::vector<PushedDownFilterInfo>& getPushedDownFilterInfo() const;

omniscidb/QueryEngine/RelAlgExecutor.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ RelAlgExecutor::RelAlgExecutor(Executor* executor,
108108
std::unique_ptr<hdk::ir::QueryDag> query_dag)
109109
: executor_(executor)
110110
, query_dag_(std::move(query_dag))
111-
, schema_provider_(std::make_shared<RelAlgSchemaProvider>(*query_dag_->getRootNode()))
111+
, schema_provider_(schema_provider)
112112
, data_provider_(executor->getDataMgr()->getDataProvider())
113113
, config_(executor_->getConfig())
114114
, now_(0)
@@ -118,18 +118,14 @@ RelAlgExecutor::RelAlgExecutor(Executor* executor,
118118

119119
// Add ResultSetRegistry to the schema provider by wrapping the current provider
120120
// and the registry in SchemaMgr.
121-
std::set<int> used_schemas;
122-
for (int db_id : schema_provider_->listDatabases()) {
123-
used_schemas.insert(getSchemaId(db_id));
124-
}
125-
auto schema_mgr = std::make_shared<SchemaMgr>();
126-
for (int schema_id : used_schemas) {
127-
schema_mgr->registerProvider(schema_id, schema_provider_);
128-
}
129-
schema_mgr->registerProvider(
130-
hdk::ResultSetRegistry::SCHEMA_ID,
131-
std::dynamic_pointer_cast<hdk::ResultSetRegistry>(rs_registry_));
132-
schema_provider_ = schema_mgr;
121+
// TODO: In the future we expect pre-initialized registry and passed schema provider
122+
// to cover it.
123+
auto db_ids = schema_provider->listDatabases();
124+
if (std::find(db_ids.begin(), db_ids.end(), hdk::ResultSetRegistry::DB_ID) ==
125+
db_ids.end()) {
126+
schema_provider_ =
127+
mergeProviders(std::vector<SchemaProviderPtr>({schema_provider, rs_registry_}));
128+
}
133129
}
134130

135131
RelAlgExecutor::~RelAlgExecutor() {

omniscidb/ResultSetRegistry/ResultSetRegistry.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ TableFragmentsInfo getEmptyTableMetadata(int table_id) {
4848

4949
} // namespace
5050

51+
// Convenience constructor: delegates to the (config, schema_name, db_id)
// constructor using the schema name "rs_registry". No db_id is passed, so the
// delegated constructor's default argument for db_id applies.
ResultSetRegistry::ResultSetRegistry(ConfigPtr config)
52+
: ResultSetRegistry(config, "rs_registry") {}
53+
5154
ResultSetRegistry::ResultSetRegistry(ConfigPtr config,
5255
const std::string& schema_name,
5356
int db_id)
@@ -89,10 +92,11 @@ ResultSetTableTokenPtr ResultSetRegistry::put(ResultSetTable table) {
8992
mapd_unique_lock<mapd_shared_mutex> data_lock(data_mutex_);
9093

9194
auto table_id = next_table_id_++;
95+
auto table_name = std::string("__result_set_") + std::to_string(table_id);
9296
// Add schema information for the ResultSet.
9397
auto tinfo = addTableInfo(db_id_,
9498
table_id,
95-
ResultSetTableToken::tableName(table_id),
99+
table_name,
96100
false,
97101
Data_Namespace::MemoryLevel::CPU_LEVEL,
98102
table.size());
@@ -105,6 +109,7 @@ ResultSetTableTokenPtr ResultSetRegistry::put(ResultSetTable table) {
105109
first_rs->colType(col_idx),
106110
false);
107111
}
112+
addRowidColumn(db_id_, table_id);
108113

109114
// TODO: lazily compute row count and try to avoid global write
110115
// locks for that

omniscidb/ResultSetRegistry/ResultSetRegistry.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ class ResultSetRegistry : public SimpleSchemaProvider,
2929
constexpr static int SCHEMA_ID = 100;
3030
constexpr static int DB_ID = (SCHEMA_ID << 24) + 1;
3131

32-
ResultSetRegistry(ConfigPtr config,
33-
const std::string& schema_name = "rs_registry",
34-
int db_id = DB_ID);
32+
ResultSetRegistry(ConfigPtr config);
33+
ResultSetRegistry(ConfigPtr config, const std::string& schema_name, int db_id = DB_ID);
3534

3635
static std::shared_ptr<ResultSetRegistry> getOrCreate(Data_Namespace::DataMgr* data_mgr,
3736
ConfigPtr config);

omniscidb/ResultSetRegistry/ResultSetTableToken.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,7 @@ class ResultSetTableToken {
4242
size_t resultSetCount() const { return tinfo_->fragments; }
4343
ResultSetPtr resultSet(size_t idx) const;
4444

45-
static std::string tableName(int table_id) {
46-
return std::string("__result_set_") + std::to_string(table_id);
47-
}
48-
49-
std::string tableName() const { return tableName(tableId()); }
45+
// Name of the table this token refers to, read from the table info (tinfo_)
// rather than rebuilt from the table id. The returned reference points into
// *tinfo_, so it is only usable while that table info object is alive.
const std::string& tableName() const { return tinfo_->name; }
5046

5147
std::string toString() const {
5248
return "ResultSetTableToken(" + std::to_string(dbId()) + ":" +

omniscidb/SchemaMgr/SchemaMgr.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,18 @@ const SchemaProvider* SchemaMgr::getMgr(int db_id) const {
8080
}
8181
return nullptr;
8282
}
83+
84+
// Combines several schema providers into one by wrapping them in a new
// SchemaMgr.
//
// For every provider (processed in the order given) the schema ids of all
// databases it lists are collected, and the provider is then registered in the
// SchemaMgr once per distinct schema id.
//
// providers: providers to merge; each entry is dereferenced, so entries are
//            expected to be non-null.
// Returns the newly created SchemaMgr as a SchemaProviderPtr.
//
// NOTE(review): what happens when two providers cover the same schema id is
// decided by SchemaMgr::registerProvider, which is not visible here - confirm.
SchemaProviderPtr mergeProviders(const std::vector<SchemaProviderPtr>& providers) {
85+
auto res = std::make_shared<SchemaMgr>();
86+
for (auto& provider : providers) {
87+
// Collect schema ids in a set first: several databases of one provider may
// map to the same schema id, and the provider is registered once per id.
std::set<int> provider_schemas;
88+
for (int db_id : provider->listDatabases()) {
89+
int schema_id = getSchemaId(db_id);
90+
provider_schemas.insert(schema_id);
91+
}
92+
for (int schema_id : provider_schemas) {
93+
res->registerProvider(schema_id, provider);
94+
}
95+
}
96+
return res;
97+
}

omniscidb/SchemaMgr/SchemaMgr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ class SchemaMgr : public SchemaProvider {
3838
};
3939

4040
using SchemaMgrPtr = std::shared_ptr<SchemaMgr>;
41+
42+
SchemaProviderPtr mergeProviders(const std::vector<SchemaProviderPtr>& providers);

omniscidb/Tests/ArrowSQLRunner/ArrowSQLRunner.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class ArrowSQLRunnerImpl {
9090
auto dag =
9191
std::make_unique<RelAlgDagBuilder>(query_ra, TEST_DB_ID, storage_, config_);
9292

93-
return std::make_unique<RelAlgExecutor>(executor_.get(), storage_, std::move(dag));
93+
return std::make_unique<RelAlgExecutor>(executor_.get(), schema_mgr_, std::move(dag));
9494
}
9595

9696
ExecutionResult runSqlQuery(const std::string& sql,
@@ -290,6 +290,10 @@ class ArrowSQLRunnerImpl {
290290

291291
std::shared_ptr<ArrowStorage> getStorage() { return storage_; }
292292

293+
// Returns the runner's SchemaMgr (schema_mgr_) as a generic schema provider.
SchemaProviderPtr getSchemaProvider() { return schema_mgr_; }
294+
295+
// Returns the result set registry (rs_registry_) held by the runner.
std::shared_ptr<hdk::ResultSetRegistry> getResultSetRegistry() { return rs_registry_; }
296+
293297
DataMgr* getDataMgr() { return data_mgr_.get(); }
294298

295299
Executor* getExecutor() { return executor_.get(); }
@@ -299,6 +303,8 @@ class ArrowSQLRunnerImpl {
299303
~ArrowSQLRunnerImpl() {
300304
executor_.reset();
301305
storage_.reset();
306+
rs_registry_.reset();
307+
schema_mgr_.reset();
302308
calcite_.reset();
303309
rel_alg_cache_.reset();
304310

@@ -314,16 +320,21 @@ class ArrowSQLRunnerImpl {
314320
}
315321

316322
storage_ = std::make_shared<ArrowStorage>(TEST_SCHEMA_ID, "test", TEST_DB_ID);
323+
rs_registry_ = std::make_shared<hdk::ResultSetRegistry>(config_);
324+
schema_mgr_ = std::make_shared<SchemaMgr>();
325+
schema_mgr_->registerProvider(TEST_SCHEMA_ID, storage_);
326+
schema_mgr_->registerProvider(hdk::ResultSetRegistry::SCHEMA_ID, rs_registry_);
317327

318328
data_mgr_ = std::make_unique<DataMgr>(*config_);
319329
auto* ps_mgr = data_mgr_->getPersistentStorageMgr();
320330
ps_mgr->registerDataProvider(TEST_SCHEMA_ID, storage_);
331+
ps_mgr->registerDataProvider(hdk::ResultSetRegistry::SCHEMA_ID, rs_registry_);
321332

322333
executor_ = Executor::getExecutor(data_mgr_.get(), config_, "", "");
323-
executor_->setSchemaProvider(storage_);
334+
executor_->setSchemaProvider(schema_mgr_);
324335

325336
if (config_->debug.use_ra_cache.empty() || !config_->debug.build_ra_cache.empty()) {
326-
calcite_ = std::make_shared<CalciteJNI>(storage_, config_, udf_filename, 1024);
337+
calcite_ = std::make_shared<CalciteJNI>(schema_mgr_, config_, udf_filename, 1024);
327338

328339
if (config_->debug.use_ra_cache.empty()) {
329340
ExtensionFunctionsWhitelist::add(calcite_->getExtensionFunctionWhitelist());
@@ -336,13 +347,15 @@ class ArrowSQLRunnerImpl {
336347
calcite_->setRuntimeExtensionFunctions({}, /*is_runtime=*/false);
337348
}
338349

339-
rel_alg_cache_ = std::make_shared<RelAlgCache>(calcite_, storage_, config_);
350+
rel_alg_cache_ = std::make_shared<RelAlgCache>(calcite_, schema_mgr_, config_);
340351
}
341352

342353
ConfigPtr config_;
343354
std::unique_ptr<DataMgr> data_mgr_;
344355
std::shared_ptr<Executor> executor_;
345356
std::shared_ptr<ArrowStorage> storage_;
357+
std::shared_ptr<hdk::ResultSetRegistry> rs_registry_;
358+
std::shared_ptr<SchemaMgr> schema_mgr_;
346359
std::shared_ptr<CalciteJNI> calcite_;
347360
std::shared_ptr<RelAlgCache> rel_alg_cache_;
348361

@@ -491,6 +504,14 @@ std::shared_ptr<ArrowStorage> getStorage() {
491504
return ArrowSQLRunnerImpl::get()->getStorage();
492505
}
493506

507+
// Free-function accessor: forwards to getSchemaProvider() on the
// ArrowSQLRunnerImpl instance returned by ArrowSQLRunnerImpl::get().
SchemaProviderPtr getSchemaProvider() {
508+
return ArrowSQLRunnerImpl::get()->getSchemaProvider();
509+
}
510+
511+
// Free-function accessor: forwards to getResultSetRegistry() on the
// ArrowSQLRunnerImpl instance returned by ArrowSQLRunnerImpl::get().
std::shared_ptr<hdk::ResultSetRegistry> getResultSetRegistry() {
512+
return ArrowSQLRunnerImpl::get()->getResultSetRegistry();
513+
}
514+
494515
DataMgr* getDataMgr() {
495516
return ArrowSQLRunnerImpl::get()->getDataMgr();
496517
}

omniscidb/Tests/ArrowSQLRunner/ArrowSQLRunner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "QueryEngine/ArrowResultSet.h"
2222
#include "QueryEngine/CompilationOptions.h"
2323
#include "QueryEngine/Descriptors/RelAlgExecutionDescriptor.h"
24+
#include "ResultSetRegistry/ResultSetRegistry.h"
2425
#include "Shared/Config.h"
2526

2627
#include "BufferPoolStats.h"
@@ -115,6 +116,10 @@ BufferPoolStats getBufferPoolStats(const Data_Namespace::MemoryLevel memory_leve
115116

116117
std::shared_ptr<ArrowStorage> getStorage();
117118

119+
SchemaProviderPtr getSchemaProvider();
120+
121+
std::shared_ptr<hdk::ResultSetRegistry> getResultSetRegistry();
122+
118123
DataMgr* getDataMgr();
119124

120125
Executor* getExecutor();

omniscidb/Tests/QueryBuilderTest.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ class QueryBuilderTest : public TestSuite {
269269
schema_mgr_ = std::make_shared<SchemaMgr>();
270270
schema_mgr_->registerProvider(TEST_SCHEMA_ID, getStorage());
271271
schema_mgr_->registerProvider(TEST_SCHEMA_ID2, storage2_);
272+
schema_mgr_->registerProvider(hdk::ResultSetRegistry::SCHEMA_ID,
273+
getResultSetRegistry());
272274

273275
createTable("test1",
274276
{{"col_bi", ctx().int64()},
@@ -4677,6 +4679,75 @@ TEST_F(Issue355, Reproducer) {
46774679
compare_res_data(res, std::vector<int64_t>({1}), std::vector<int64_t>({12}));
46784680
}
46794681

4682+
// Checks that the result of an executed query can be used as the input table
// of a later query built with QueryBuilder, by scanning it under the name
// returned from ExecutionResult::tableName().
TEST_F(QueryBuilderTest, RunOnResult) {
4683+
QueryBuilder builder(ctx(), schema_mgr_, configPtr());
4684+
4685+
// Scenario 1: a chain of three queries, each scanning the previous result.
{
4686+
// Base query: first two columns of "test1" (int64 values, then int32 values).
auto dag1 = builder.scan("test1").proj({0, 1}).finalize();
4687+
auto res1 = runQuery(std::move(dag1));
4688+
compare_res_data(res1,
4689+
std::vector<int64_t>({1, 2, 3, 4, 5}),
4690+
std::vector<int32_t>({11, 22, 33, 44, 55}));
4691+
4692+
// Scan the first result by table name and swap the column order by index.
auto dag2 = builder.scan(res1.tableName()).proj({1, 0}).finalize();
4693+
auto res2 = runQuery(std::move(dag2));
4694+
compare_res_data(res2,
4695+
std::vector<int32_t>({11, 22, 33, 44, 55}),
4696+
std::vector<int64_t>({1, 2, 3, 4, 5}));
4697+
4698+
// Scan the second result and reference its columns by name inside
// expressions; this relies on result columns being addressable as "col_i"
// and "col_bi".
auto scan = builder.scan(res2.tableName());
4699+
auto dag3 = scan.proj({scan["col_i"] + 1, scan["col_bi"] + 2}).finalize();
4700+
auto res3 = runQuery(std::move(dag3));
4701+
compare_res_data(res3,
4702+
std::vector<int32_t>({12, 23, 34, 45, 56}),
4703+
std::vector<int64_t>({3, 4, 5, 6, 7}));
4704+
}
4705+
4706+
// Scenario 2: two independent results joined with each other.
{
4707+
auto scan1 = builder.scan("test1");
4708+
auto dag1 = scan1.proj({scan1["col_bi"] + 1, scan1["col_f"]}).finalize();
4709+
auto res1 = runQuery(std::move(dag1));
4710+
compare_res_data(res1,
4711+
std::vector<int64_t>({2, 3, 4, 5, 6}),
4712+
std::vector<float>({1.1, 2.2, 3.3, 4.4, 5.5}));
4713+
4714+
auto scan2 = builder.scan("test1");
4715+
auto dag2 = scan2.proj({scan2["col_bi"] + 2, scan2["col_d"]}).finalize();
4716+
auto res2 = runQuery(std::move(dag2));
4717+
compare_res_data(res2,
4718+
std::vector<int64_t>({3, 4, 5, 6, 7}),
4719+
std::vector<double>({11.11, 22.22, 33.33, 44.44, 55.55}));
4720+
4721+
// Join the two results. The expected output keeps only the four key values
// present in both inputs (3..6).
// NOTE(review): join() is called without an explicit condition; presumably
// it matches on the first / same-named column - confirm in QueryBuilder.
auto dag3 =
4722+
builder.scan(res1.tableName()).join(builder.scan(res2.tableName())).finalize();
4723+
auto res3 = runQuery(std::move(dag3));
4724+
compare_res_data(res3,
4725+
std::vector<int64_t>({3, 4, 5, 6}),
4726+
std::vector<float>({2.2, 3.3, 4.4, 5.5}),
4727+
std::vector<double>({11.11, 22.22, 33.33, 44.44}));
4728+
}
4729+
}
4730+
4731+
// Checks that a query result scanned as a table exposes a "rowid" column that
// can be projected by name alongside the regular result columns.
TEST_F(QueryBuilderTest, RowidOnResult) {
4732+
QueryBuilder builder(ctx(), schema_mgr_, configPtr());
4733+
4734+
{
4735+
// Base query: first two columns of "test1".
auto dag1 = builder.scan("test1").proj({0, 1}).finalize();
4736+
auto res1 = runQuery(std::move(dag1));
4737+
compare_res_data(res1,
4738+
std::vector<int64_t>({1, 2, 3, 4, 5}),
4739+
std::vector<int32_t>({11, 22, 33, 44, 55}));
4740+
4741+
// Project by name from the result table, including "rowid"; the expected
// rowid values are the zero-based row positions 0..4.
auto dag2 =
4742+
builder.scan(res1.tableName()).proj({"col_i", "col_bi", "rowid"}).finalize();
4743+
auto res2 = runQuery(std::move(dag2));
4744+
compare_res_data(res2,
4745+
std::vector<int32_t>({11, 22, 33, 44, 55}),
4746+
std::vector<int64_t>({1, 2, 3, 4, 5}),
4747+
std::vector<int64_t>({0, 1, 2, 3, 4}));
4748+
}
4749+
}
4750+
46804751
class Taxi : public TestSuite {
46814752
protected:
46824753
static void SetUpTestSuite() {

0 commit comments

Comments
 (0)