Skip to content

Commit 671897e

Browse files
authored
refactor(codegen): null safe for struct ir builder (#3547)
fixes #926
1 parent bba4e51 commit 671897e

18 files changed

+197
-173
lines changed

cases/query/const_query.yaml

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,55 @@ cases:
126126
columns: ['c1 bool', 'c2 int16', 'c3 int', 'c4 double', 'c5 string', 'c6 date', 'c7 timestamp' ]
127127
rows:
128128
- [ true, 3, 13, 10.0, 'a string', '2020-05-22', 1590115420000 ]
129+
130+
# =================================================================================
131+
# Null safe for structure types: String, Date, Timestamp and Array
132+
# creating struct from:
133+
# 1. NULL literal (const null)
134+
# 2. another supported data type whose cast fails, e.g. timestamp(-1) returns a NULL timestamp
135+
#
136+
# casting to array type is not implemented yet
137+
# =================================================================================
129138
- id: 10
139+
desc: null safe for date
130140
mode: procedure-unsupport
131141
sql: |
132142
select
133143
datediff(Date(timestamp(-1)), Date("2021-05-01")) as out1,
134144
datediff(Date(timestamp(-2177481600)), Date("2021-05-01")) as out2,
135-
datediff(cast(NULL as date), Date("2021-05-01")) as out3
136-
;
145+
datediff(cast(NULL as date), Date("2021-05-01")) as out3,
146+
date(NULL) as out4,
147+
date("abc") as out5,
148+
date(timestamp("abc")) as out6
149+
expect:
150+
columns: ["out1 int", "out2 int", "out3 int", "out4 date", "out5 date", "out6 date"]
151+
data: |
152+
NULL, NULL, NULL, NULL, NULL, NULL
153+
- id: 11
154+
desc: null safe for timestamp
155+
mode: procedure-unsupport
156+
sql: |
157+
select
158+
month(cast(NULL as timestamp)) as out1,
159+
month(timestamp(NULL)) as out2,
160+
month(timestamp(-1)) as out3,
161+
month(timestamp("abc")) as out4,
162+
month(timestamp(date("abc"))) as out5
163+
expect:
164+
columns: ["out1 int", "out2 int", "out3 int", "out4 int", "out5 int"]
165+
data: |
166+
NULL, NULL, NULL, NULL, NULL
167+
- id: 12
168+
desc: null safe for string
169+
mode: procedure-unsupport
170+
sql: |
171+
select
172+
char_length(cast(NULL as string)) as out1,
173+
char_length(string(int(NULL))) as out2,
174+
char_length(string(bool(null))) as out3,
175+
char_length(string(timestamp(null))) as out4,
176+
char_length(string(date(null))) as out5
137177
expect:
138-
columns: ["out1 int", "out2 int", "out3 int"]
178+
columns: ["out1 int", "out2 int", "out3 int", "out4 int", "out5 int"]
139179
data: |
140-
NULL, NULL, NULL
180+
NULL, NULL, NULL, NULL, NULL

hybridse/src/codegen/array_ir_builder.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,21 @@ base::Status ArrayIRBuilder::NewEmptyArray(llvm::BasicBlock* bb, NativeValue* ou
114114
return base::Status::OK();
115115
}
116116

117+
bool ArrayIRBuilder::CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) {
118+
llvm::Value* array_alloca = nullptr;
119+
if (!Create(block, &array_alloca)) {
120+
return false;
121+
}
122+
123+
llvm::IRBuilder<> builder(block);
124+
::llvm::Value* array_sz = builder.getInt64(0);
125+
if (!Set(block, array_alloca, 2, array_sz)) {
126+
return false;
127+
}
128+
129+
*output = array_alloca;
130+
return true;
131+
}
132+
117133
} // namespace codegen
118134
} // namespace hybridse

hybridse/src/codegen/array_ir_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ class ArrayIRBuilder : public StructTypeIRBuilder {
4949

5050
void InitStructType() override;
5151

52-
bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override { return true; }
52+
bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override;
5353

5454
bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override { return true; }
5555

5656
base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override {
57-
return base::Status::OK();
57+
CHECK_TRUE(false, common::kCodegenError, "casting to array un-implemented");
5858
};
5959

6060
private:

hybridse/src/codegen/cast_expr_ir_builder.cc

Lines changed: 32 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
*/
1616

1717
#include "codegen/cast_expr_ir_builder.h"
18+
1819
#include "codegen/date_ir_builder.h"
1920
#include "codegen/ir_base_builder.h"
2021
#include "codegen/string_ir_builder.h"
2122
#include "codegen/timestamp_ir_builder.h"
23+
#include "codegen/type_ir_builder.h"
2224
#include "glog/logging.h"
2325
#include "node/node_manager.h"
26+
#include "proto/fe_common.pb.h"
2427

2528
using hybridse::common::kCodegenError;
2629

@@ -72,98 +75,73 @@ Status CastExprIRBuilder::Cast(const NativeValue& value,
7275
}
7376
return Status::OK();
7477
}
75-
Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* type,
76-
NativeValue* output) {
78+
79+
Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output) {
7780
::llvm::IRBuilder<> builder(block_);
78-
CHECK_TRUE(IsSafeCast(value.GetType(), type), kCodegenError,
79-
"Safe cast fail: unsafe cast");
81+
CHECK_TRUE(IsSafeCast(value.GetType(), dst_type), kCodegenError, "Safe cast fail: unsafe cast");
8082
Status status;
8183
if (value.IsConstNull()) {
82-
if (TypeIRBuilder::IsStringPtr(type)) {
83-
StringIRBuilder string_ir_builder(block_->getModule());
84-
CHECK_STATUS(string_ir_builder.CreateNull(block_, output));
85-
return base::Status::OK();
86-
} else {
87-
*output = NativeValue::CreateNull(type);
88-
}
89-
} else if (TypeIRBuilder::IsTimestampPtr(type)) {
84+
auto res = CreateSafeNull(block_, dst_type);
85+
CHECK_TRUE(res.ok(), kCodegenError, res.status().ToString());
86+
*output = res.value();
87+
} else if (TypeIRBuilder::IsTimestampPtr(dst_type)) {
9088
TimestampIRBuilder timestamp_ir_builder(block_->getModule());
9189
CHECK_STATUS(timestamp_ir_builder.CastFrom(block_, value, output));
9290
return Status::OK();
93-
} else if (TypeIRBuilder::IsDatePtr(type)) {
91+
} else if (TypeIRBuilder::IsDatePtr(dst_type)) {
9492
DateIRBuilder date_ir_builder(block_->getModule());
9593
CHECK_STATUS(date_ir_builder.CastFrom(block_, value, output));
9694
return Status::OK();
97-
} else if (TypeIRBuilder::IsStringPtr(type)) {
95+
} else if (TypeIRBuilder::IsStringPtr(dst_type)) {
9896
StringIRBuilder string_ir_builder(block_->getModule());
9997
CHECK_STATUS(string_ir_builder.CastFrom(block_, value, output));
10098
return Status::OK();
101-
} else if (TypeIRBuilder::IsNumber(type)) {
99+
} else if (TypeIRBuilder::IsNumber(dst_type)) {
102100
Status status;
103101
::llvm::Value* output_value = nullptr;
104-
CHECK_TRUE(SafeCastNumber(value.GetValue(&builder), type, &output_value,
105-
status),
106-
kCodegenError);
102+
CHECK_TRUE(SafeCastNumber(value.GetValue(&builder), dst_type, &output_value, status), kCodegenError);
107103
if (value.IsNullable()) {
108-
*output = NativeValue::CreateWithFlag(output_value,
109-
value.GetIsNull(&builder));
104+
*output = NativeValue::CreateWithFlag(output_value, value.GetIsNull(&builder));
110105
} else {
111106
*output = NativeValue::Create(output_value);
112107
}
113108
} else {
114-
return Status(common::kCodegenError,
115-
"Can't cast from " +
116-
TypeIRBuilder::TypeName(value.GetType()) + " to " +
117-
TypeIRBuilder::TypeName(type));
109+
return Status(common::kCodegenError, "Can't cast from " + TypeIRBuilder::TypeName(value.GetType()) + " to " +
110+
TypeIRBuilder::TypeName(dst_type));
118111
}
119112
return Status::OK();
120113
}
121-
Status CastExprIRBuilder::UnSafeCast(const NativeValue& value,
122-
::llvm::Type* type, NativeValue* output) {
123-
::llvm::IRBuilder<> builder(block_);
124-
if (value.IsConstNull()) {
125-
if (TypeIRBuilder::IsStringPtr(type)) {
126-
StringIRBuilder string_ir_builder(block_->getModule());
127-
CHECK_STATUS(string_ir_builder.CreateNull(block_, output));
128-
return base::Status::OK();
129114

130-
} else if (TypeIRBuilder::IsDatePtr(type)) {
131-
DateIRBuilder date_ir(block_->getModule());
132-
CHECK_STATUS(date_ir.CreateNull(block_, output));
133-
return base::Status::OK();
134-
} else {
135-
*output = NativeValue::CreateNull(type);
136-
}
137-
} else if (TypeIRBuilder::IsTimestampPtr(type)) {
115+
Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output) {
116+
::llvm::IRBuilder<> builder(block_);
117+
if (value.IsConstNull() || (TypeIRBuilder::IsNumber(dst_type) && TypeIRBuilder::IsDatePtr(value.GetType()))) {
118+
// input is const null or (cast date to number)
119+
auto res = CreateSafeNull(block_, dst_type);
120+
CHECK_TRUE(res.ok(), kCodegenError, res.status().ToString());
121+
*output = res.value();
122+
} else if (TypeIRBuilder::IsTimestampPtr(dst_type)) {
138123
TimestampIRBuilder timestamp_ir_builder(block_->getModule());
139124
CHECK_STATUS(timestamp_ir_builder.CastFrom(block_, value, output));
140125
return Status::OK();
141-
} else if (TypeIRBuilder::IsDatePtr(type)) {
126+
} else if (TypeIRBuilder::IsDatePtr(dst_type)) {
142127
DateIRBuilder date_ir_builder(block_->getModule());
143128
CHECK_STATUS(date_ir_builder.CastFrom(block_, value, output));
144129
return Status::OK();
145-
} else if (TypeIRBuilder::IsStringPtr(type)) {
130+
} else if (TypeIRBuilder::IsStringPtr(dst_type)) {
146131
StringIRBuilder string_ir_builder(block_->getModule());
147132
CHECK_STATUS(string_ir_builder.CastFrom(block_, value, output));
148133
return Status::OK();
149-
} else if (TypeIRBuilder::IsNumber(type) &&
150-
TypeIRBuilder::IsStringPtr(value.GetType())) {
134+
} else if (TypeIRBuilder::IsNumber(dst_type) && TypeIRBuilder::IsStringPtr(value.GetType())) {
151135
StringIRBuilder string_ir_builder(block_->getModule());
152-
CHECK_STATUS(
153-
string_ir_builder.CastToNumber(block_, value, type, output));
136+
CHECK_STATUS(string_ir_builder.CastToNumber(block_, value, dst_type, output));
154137
return Status::OK();
155-
} else if (TypeIRBuilder::IsNumber(type) &&
156-
TypeIRBuilder::IsDatePtr(value.GetType())) {
157-
*output = NativeValue::CreateNull(type);
158138
} else {
159139
Status status;
160140
::llvm::Value* output_value = nullptr;
161-
CHECK_TRUE(UnSafeCastNumber(value.GetValue(&builder), type,
162-
&output_value, status),
163-
kCodegenError, status.msg);
141+
CHECK_TRUE(UnSafeCastNumber(value.GetValue(&builder), dst_type, &output_value, status), kCodegenError,
142+
status.msg);
164143
if (value.IsNullable()) {
165-
*output = NativeValue::CreateWithFlag(output_value,
166-
value.GetIsNull(&builder));
144+
*output = NativeValue::CreateWithFlag(output_value, value.GetIsNull(&builder));
167145
} else {
168146
*output = NativeValue::Create(output_value);
169147
}

hybridse/src/codegen/cast_expr_ir_builder.h

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
#define HYBRIDSE_SRC_CODEGEN_CAST_EXPR_IR_BUILDER_H_
1919
#include "base/fe_status.h"
2020
#include "codegen/cond_select_ir_builder.h"
21-
#include "codegen/scope_var.h"
22-
#include "llvm/IR/IRBuilder.h"
23-
#include "proto/fe_type.pb.h"
2421

2522
namespace hybridse {
2623
namespace codegen {
@@ -32,26 +29,19 @@ class CastExprIRBuilder {
3229
explicit CastExprIRBuilder(::llvm::BasicBlock* block);
3330
~CastExprIRBuilder();
3431

35-
Status Cast(const NativeValue& value, ::llvm::Type* cast_type,
36-
NativeValue* output); // NOLINT
37-
Status SafeCast(const NativeValue& value, ::llvm::Type* type,
38-
NativeValue* output); // NOLINT
39-
Status UnSafeCast(const NativeValue& value, ::llvm::Type* type,
40-
NativeValue* output); // NOLINT
32+
Status Cast(const NativeValue& value, ::llvm::Type* cast_type, NativeValue* output);
33+
Status SafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output);
34+
Status UnSafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output);
4135
static bool IsSafeCast(::llvm::Type* lhs, ::llvm::Type* rhs);
42-
static Status InferNumberCastTypes(::llvm::Type* left_type,
43-
::llvm::Type* right_type);
36+
static Status InferNumberCastTypes(::llvm::Type* left_type, ::llvm::Type* right_type);
4437
static bool IsIntFloat2PointerCast(::llvm::Type* src, ::llvm::Type* dist);
4538
bool BoolCast(llvm::Value* pValue, llvm::Value** pValue1,
4639
base::Status& status); // NOLINT
47-
bool SafeCastNumber(::llvm::Value* value, ::llvm::Type* type,
48-
::llvm::Value** output,
40+
bool SafeCastNumber(::llvm::Value* value, ::llvm::Type* type, ::llvm::Value** output,
4941
base::Status& status); // NOLINT
50-
bool UnSafeCastNumber(::llvm::Value* value, ::llvm::Type* type,
51-
::llvm::Value** output,
42+
bool UnSafeCastNumber(::llvm::Value* value, ::llvm::Type* type, ::llvm::Value** output,
5243
base::Status& status); // NOLINT
53-
bool UnSafeCastDouble(::llvm::Value* value, ::llvm::Type* type,
54-
::llvm::Value** output,
44+
bool UnSafeCastDouble(::llvm::Value* value, ::llvm::Type* type, ::llvm::Value** output,
5545
base::Status& status); // NOLINT
5646

5747
private:

hybridse/src/codegen/date_ir_builder.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,6 @@ void DateIRBuilder::InitStructType() {
4545
return;
4646
}
4747

48-
base::Status DateIRBuilder::CreateNull(::llvm::BasicBlock* block, NativeValue* output) {
49-
::llvm::Value* value = nullptr;
50-
CHECK_TRUE(CreateDefault(block, &value), common::kCodegenError, "Fail to construct string")
51-
::llvm::IRBuilder<> builder(block);
52-
*output = NativeValue::CreateWithFlag(value, builder.getInt1(true));
53-
return base::Status::OK();
54-
}
55-
5648
bool DateIRBuilder::CreateDefault(::llvm::BasicBlock* block,
5749
::llvm::Value** output) {
5850
return NewDate(block, output);

hybridse/src/codegen/date_ir_builder.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@ class DateIRBuilder : public StructTypeIRBuilder {
2727
public:
2828
explicit DateIRBuilder(::llvm::Module* m);
2929
~DateIRBuilder();
30-
void InitStructType() override;
3130

32-
base::Status CreateNull(::llvm::BasicBlock* block, NativeValue* output);
31+
void InitStructType() override;
3332
bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override;
33+
bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override;
34+
base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override;
3435

3536
bool NewDate(::llvm::BasicBlock* block, ::llvm::Value** output);
36-
bool NewDate(::llvm::BasicBlock* block, ::llvm::Value* date,
37-
::llvm::Value** output);
38-
bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist);
39-
base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output);
37+
bool NewDate(::llvm::BasicBlock* block, ::llvm::Value* date, ::llvm::Value** output);
4038

4139
bool GetDate(::llvm::BasicBlock* block, ::llvm::Value* date,
4240
::llvm::Value** output);

hybridse/src/codegen/expr_ir_builder.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@
2626
#include "codegen/cond_select_ir_builder.h"
2727
#include "codegen/context.h"
2828
#include "codegen/date_ir_builder.h"
29-
#include "codegen/fn_ir_builder.h"
3029
#include "codegen/ir_base_builder.h"
3130
#include "codegen/list_ir_builder.h"
32-
#include "codegen/struct_ir_builder.h"
3331
#include "codegen/timestamp_ir_builder.h"
3432
#include "codegen/type_ir_builder.h"
3533
#include "codegen/udf_ir_builder.h"
@@ -217,8 +215,7 @@ Status ExprIRBuilder::BuildConstExpr(
217215
::llvm::IRBuilder<> builder(ctx_->GetCurrentBlock());
218216
switch (const_node->GetDataType()) {
219217
case ::hybridse::node::kNull: {
220-
*output = NativeValue::CreateNull(
221-
llvm::Type::getTokenTy(builder.getContext()));
218+
*output = NativeValue(nullptr, nullptr, llvm::Type::getTokenTy(builder.getContext()));
222219
break;
223220
}
224221
case ::hybridse::node::kBool: {

hybridse/src/codegen/ir_base_builder.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "codegen/ir_base_builder.h"
1818

1919
#include <string>
20-
#include <tuple>
2120
#include <utility>
2221
#include <vector>
2322

@@ -625,21 +624,25 @@ bool GetBaseType(::llvm::Type* type, ::hybridse::node::DataType* output) {
625624
return false;
626625
}
627626

628-
if (pointee_ty->getStructName().startswith("fe.list_ref_")) {
627+
auto struct_name = pointee_ty->getStructName();
628+
if (struct_name.startswith("fe.list_ref_")) {
629629
*output = hybridse::node::kList;
630630
return true;
631-
} else if (pointee_ty->getStructName().startswith("fe.iterator_ref_")) {
631+
} else if (struct_name.startswith("fe.iterator_ref_")) {
632632
*output = hybridse::node::kIterator;
633633
return true;
634-
} else if (pointee_ty->getStructName().equals("fe.string_ref")) {
634+
} else if (struct_name.equals("fe.string_ref")) {
635635
*output = hybridse::node::kVarchar;
636636
return true;
637-
} else if (pointee_ty->getStructName().equals("fe.timestamp")) {
637+
} else if (struct_name.equals("fe.timestamp")) {
638638
*output = hybridse::node::kTimestamp;
639639
return true;
640-
} else if (pointee_ty->getStructName().equals("fe.date")) {
640+
} else if (struct_name.equals("fe.date")) {
641641
*output = hybridse::node::kDate;
642642
return true;
643+
} else if (struct_name.startswith("fe.array_")) {
644+
*output = hybridse::node::kArray;
645+
return true;
643646
}
644647
LOG(WARNING) << "no mapping pointee_ty for llvm pointee_ty "
645648
<< pointee_ty->getStructName().str();

0 commit comments

Comments
 (0)