Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 629dd98

Browse files
committed
Refactor array loops generation.
Signed-off-by: ienkovich <[email protected]>
1 parent 5fbdfa2 commit 629dd98

File tree

5 files changed

+109
-60
lines changed

5 files changed

+109
-60
lines changed

omniscidb/IR/ExprCollector.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ class ExprCollector : public ExprVisitor<void> {
3434
return std::move(collector.result_);
3535
}
3636

37+
template <typename... Ts>
38+
static ResultType collect(const std::vector<const Expr*>& exprs, Ts&&... args) {
39+
CollectorType collector(std::forward<Ts>(args)...);
40+
for (auto& expr : exprs) {
41+
collector.visit(expr);
42+
}
43+
return std::move(collector.result_);
44+
}
45+
3746
ResultType& result() { return result_; }
3847
const ResultType& result() const { return result_; }
3948

omniscidb/QueryEngine/ArrayIR.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
llvm::Value* CodeGenerator::codegenUnnest(const hdk::ir::UOper* uoper,
2121
const CompilationOptions& co) {
2222
AUTOMATIC_IR_METADATA(cgen_state_);
23-
return codegen(uoper->operand(), true, co).front();
23+
auto array_lv = codegen(uoper->operand(), true, co).front();
24+
if (!cgen_state_->unnest_cache_.count(array_lv)) {
25+
throw std::runtime_error("Unsupported context for UNNEST operation.");
26+
}
27+
return cgen_state_->unnest_cache_.at(array_lv);
2428
}
2529

2630
llvm::Value* CodeGenerator::codegenArrayAt(const hdk::ir::BinOper* array_at,

omniscidb/QueryEngine/CgenState.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ struct CgenState {
389389
str_dict_translation_mgrs_;
390390
std::map<std::pair<llvm::Value*, llvm::Value*>, ArrayLoadCodegen>
391391
array_load_cache_; // byte stream to array info
392+
std::unordered_map<llvm::Value*, llvm::Value*> unnest_cache_;
392393
bool needs_error_check_;
393394
bool automatic_ir_metadata_;
394395

omniscidb/QueryEngine/Execute.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,10 @@ class Executor : public StringDictionaryProxyProvider {
821821
DiamondCodegen&,
822822
std::stack<llvm::BasicBlock*>&,
823823
const bool thread_mem_shared);
824+
llvm::Value* arrayLoopCodegen(const hdk::ir::Expr* array_expr,
825+
std::stack<llvm::BasicBlock*>& array_loops,
826+
DiamondCodegen& diamond_codegen,
827+
const CompilationOptions& co);
824828

825829
llvm::Value* castToFP(llvm::Value*,
826830
const hdk::ir::Type* from_type,

omniscidb/QueryEngine/IRCodegen.cpp

Lines changed: 90 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,83 +1169,32 @@ Executor::GroupColLLVMValue Executor::groupByColumnCodegen(
11691169
const bool thread_mem_shared) {
11701170
AUTOMATIC_IR_METADATA(cgen_state_.get());
11711171
CHECK_GE(col_width, sizeof(int32_t));
1172+
llvm::Value* group_key;
1173+
llvm::Value* key_to_cache;
11721174
CodeGenerator code_generator(this, co.codegen_traits_desc);
1173-
auto group_key = code_generator.codegen(group_by_col, true, co).front();
1174-
auto key_to_cache = group_key;
11751175
if (group_by_col && group_by_col->is<hdk::ir::UOper>() &&
11761176
group_by_col->as<hdk::ir::UOper>()->isUnnest()) {
1177-
auto preheader = cgen_state_->ir_builder_.GetInsertBlock();
1178-
auto array_loop_head = llvm::BasicBlock::Create(cgen_state_->context_,
1179-
"array_loop_head",
1180-
cgen_state_->current_func_,
1181-
preheader->getNextNode());
1182-
diamond_codegen.setFalseTarget(array_loop_head);
1183-
const auto ret_ty = get_int_type(32, cgen_state_->context_);
1184-
llvm::Value* array_idx_ptr = cgen_state_->ir_builder_.CreateAlloca(ret_ty);
1185-
if (array_idx_ptr->getType()->getPointerAddressSpace() !=
1186-
co.codegen_traits_desc.local_addr_space_) {
1187-
array_idx_ptr = cgen_state_->ir_builder_.CreateAddrSpaceCast(
1188-
array_idx_ptr,
1189-
llvm::PointerType::get(array_idx_ptr->getType()->getPointerElementType(),
1190-
co.codegen_traits_desc.local_addr_space_),
1191-
"array.idx.ptrcast");
1192-
}
1193-
CHECK(array_idx_ptr);
1194-
cgen_state_->ir_builder_.CreateStore(cgen_state_->llInt(int32_t(0)), array_idx_ptr);
11951177
const auto arr_expr = group_by_col->as<hdk::ir::UOper>()->operand();
11961178
auto array_type = arr_expr->type();
11971179
CHECK(array_type->isArray());
11981180
auto elem_type = array_type->as<hdk::ir::ArrayBaseType>()->elemType();
1199-
auto array_len =
1200-
(array_type->size() > 0)
1201-
? cgen_state_->llInt(array_type->size() / elem_type->size())
1202-
: cgen_state_->emitExternalCall(
1203-
"array_size",
1204-
ret_ty,
1205-
{group_key,
1206-
code_generator.posArg(arr_expr),
1207-
cgen_state_->llInt(log2_bytes(elem_type->canonicalSize()))});
1208-
cgen_state_->ir_builder_.CreateBr(array_loop_head);
1209-
cgen_state_->ir_builder_.SetInsertPoint(array_loop_head);
1210-
CHECK(array_len);
1211-
auto array_idx = cgen_state_->ir_builder_.CreateLoad(
1212-
array_idx_ptr->getType()->getPointerElementType(), array_idx_ptr);
1213-
auto bound_check = cgen_state_->ir_builder_.CreateICmp(
1214-
llvm::ICmpInst::ICMP_SLT, array_idx, array_len);
1215-
auto array_loop_body = llvm::BasicBlock::Create(
1216-
cgen_state_->context_, "array_loop_body", cgen_state_->current_func_);
1217-
cgen_state_->ir_builder_.CreateCondBr(
1218-
bound_check,
1219-
array_loop_body,
1220-
array_loops.empty() ? diamond_codegen.orig_cond_false_ : array_loops.top());
1221-
cgen_state_->ir_builder_.SetInsertPoint(array_loop_body);
1222-
cgen_state_->ir_builder_.CreateStore(
1223-
cgen_state_->ir_builder_.CreateAdd(array_idx, cgen_state_->llInt(int32_t(1))),
1224-
array_idx_ptr);
1225-
auto array_at_fname = "array_at_" + numeric_type_name(elem_type);
1226-
if (array_type->size() < 0) {
1227-
if (!array_type->nullable()) {
1228-
array_at_fname = "notnull_" + array_at_fname;
1229-
}
1230-
array_at_fname = "varlen_" + array_at_fname;
1231-
}
12321181
const auto ar_ret_ty =
12331182
elem_type->isFloatingPoint()
12341183
? (elem_type->isFp64() ? llvm::Type::getDoubleTy(cgen_state_->context_)
12351184
: llvm::Type::getFloatTy(cgen_state_->context_))
12361185
: get_int_type(elem_type->canonicalSize() * 8, cgen_state_->context_);
1237-
group_key = cgen_state_->emitExternalCall(
1238-
array_at_fname,
1239-
ar_ret_ty,
1240-
{group_key, code_generator.posArg(arr_expr), array_idx});
1186+
1187+
group_key = arrayLoopCodegen(arr_expr, array_loops, diamond_codegen, co);
1188+
12411189
if (need_patch_unnest_double(
12421190
elem_type, isArchMaxwell(co.device_type), thread_mem_shared)) {
12431191
key_to_cache = spillDoubleElement(group_key, ar_ret_ty);
12441192
} else {
12451193
key_to_cache = group_key;
12461194
}
1247-
CHECK(array_loop_head);
1248-
array_loops.push(array_loop_head);
1195+
} else {
1196+
group_key = code_generator.codegen(group_by_col, true, co).front();
1197+
key_to_cache = group_key;
12491198
}
12501199
cgen_state_->group_by_expr_cache_.push_back(key_to_cache);
12511200
llvm::Value* orig_group_key{nullptr};
@@ -1274,6 +1223,88 @@ Executor::GroupColLLVMValue Executor::groupByColumnCodegen(
12741223
return {group_key, orig_group_key};
12751224
}
12761225

1226+
llvm::Value* Executor::arrayLoopCodegen(const hdk::ir::Expr* array_expr,
1227+
std::stack<llvm::BasicBlock*>& array_loops,
1228+
DiamondCodegen& diamond_codegen,
1229+
const CompilationOptions& co) {
1230+
AUTOMATIC_IR_METADATA(cgen_state_.get());
1231+
CodeGenerator code_generator(this, co.codegen_traits_desc);
1232+
auto array_lv = code_generator.codegen(array_expr, true, co).front();
1233+
1234+
if (cgen_state_->unnest_cache_.count(array_lv)) {
1235+
return cgen_state_->unnest_cache_.at(array_lv);
1236+
}
1237+
1238+
auto preheader = cgen_state_->ir_builder_.GetInsertBlock();
1239+
auto array_loop_head = llvm::BasicBlock::Create(cgen_state_->context_,
1240+
"array_loop_head",
1241+
cgen_state_->current_func_,
1242+
preheader->getNextNode());
1243+
diamond_codegen.setFalseTarget(array_loop_head);
1244+
const auto ret_ty = get_int_type(32, cgen_state_->context_);
1245+
llvm::Value* array_idx_ptr = cgen_state_->ir_builder_.CreateAlloca(ret_ty);
1246+
if (array_idx_ptr->getType()->getPointerAddressSpace() !=
1247+
co.codegen_traits_desc.local_addr_space_) {
1248+
array_idx_ptr = cgen_state_->ir_builder_.CreateAddrSpaceCast(
1249+
array_idx_ptr,
1250+
llvm::PointerType::get(array_idx_ptr->getType()->getPointerElementType(),
1251+
co.codegen_traits_desc.local_addr_space_),
1252+
"array.idx.ptrcast");
1253+
}
1254+
CHECK(array_idx_ptr);
1255+
cgen_state_->ir_builder_.CreateStore(cgen_state_->llInt(int32_t(0)), array_idx_ptr);
1256+
auto array_type = array_expr->type();
1257+
CHECK(array_type->isArray());
1258+
auto elem_type = array_type->as<hdk::ir::ArrayBaseType>()->elemType();
1259+
auto array_len =
1260+
(array_type->size() > 0)
1261+
? cgen_state_->llInt(array_type->size() / elem_type->size())
1262+
: cgen_state_->emitExternalCall(
1263+
"array_size",
1264+
ret_ty,
1265+
{array_lv,
1266+
code_generator.posArg(array_expr),
1267+
cgen_state_->llInt(log2_bytes(elem_type->canonicalSize()))});
1268+
cgen_state_->ir_builder_.CreateBr(array_loop_head);
1269+
cgen_state_->ir_builder_.SetInsertPoint(array_loop_head);
1270+
CHECK(array_len);
1271+
auto array_idx = cgen_state_->ir_builder_.CreateLoad(
1272+
array_idx_ptr->getType()->getPointerElementType(), array_idx_ptr);
1273+
auto bound_check =
1274+
cgen_state_->ir_builder_.CreateICmp(llvm::ICmpInst::ICMP_SLT, array_idx, array_len);
1275+
auto array_loop_body = llvm::BasicBlock::Create(
1276+
cgen_state_->context_, "array_loop_body", cgen_state_->current_func_);
1277+
cgen_state_->ir_builder_.CreateCondBr(
1278+
bound_check,
1279+
array_loop_body,
1280+
array_loops.empty() ? diamond_codegen.orig_cond_false_ : array_loops.top());
1281+
cgen_state_->ir_builder_.SetInsertPoint(array_loop_body);
1282+
array_loops.push(array_loop_head);
1283+
cgen_state_->ir_builder_.CreateStore(
1284+
cgen_state_->ir_builder_.CreateAdd(array_idx, cgen_state_->llInt(int32_t(1))),
1285+
array_idx_ptr);
1286+
auto array_at_fname = "array_at_" + numeric_type_name(elem_type);
1287+
if (array_type->size() < 0) {
1288+
if (!array_type->nullable()) {
1289+
array_at_fname = "notnull_" + array_at_fname;
1290+
}
1291+
array_at_fname = "varlen_" + array_at_fname;
1292+
}
1293+
const auto ar_ret_ty =
1294+
elem_type->isFloatingPoint()
1295+
? (elem_type->isFp64() ? llvm::Type::getDoubleTy(cgen_state_->context_)
1296+
: llvm::Type::getFloatTy(cgen_state_->context_))
1297+
: get_int_type(elem_type->canonicalSize() * 8, cgen_state_->context_);
1298+
auto res = cgen_state_->emitExternalCall(
1299+
array_at_fname,
1300+
ar_ret_ty,
1301+
{array_lv, code_generator.posArg(array_expr), array_idx});
1302+
1303+
cgen_state_->unnest_cache_.emplace(array_lv, res);
1304+
1305+
return res;
1306+
}
1307+
12771308
CodeGenerator::NullCheckCodegen::NullCheckCodegen(CgenState* cgen_state,
12781309
Executor* executor,
12791310
llvm::Value* nullable_lv,

0 commit comments

Comments
 (0)