Skip to content

Commit d44a534

Browse files
authored
Better handling for Union-type fields, particularly of singletons (#43163)
fix #43123
1 parent ae336ab commit d44a534

File tree

4 files changed

+87
-79
lines changed

4 files changed

+87
-79
lines changed

src/cgutils.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,22 +2146,25 @@ static bool emit_getfield_unknownidx(jl_codectx_t &ctx,
21462146
return false;
21472147
}
21482148

2149-
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex, jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl)
2150-
{
2151-
Instruction *tindex0 = tbaa_decorate(tbaa_unionselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2152-
//tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2153-
// ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2154-
// ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
2149+
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex,
2150+
jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl,
2151+
unsigned union_max, MDNode *tbaa_ptindex)
2152+
{
2153+
Instruction *tindex0 = tbaa_decorate(tbaa_ptindex, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2154+
tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2155+
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2156+
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
21552157
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex0);
2156-
if (mutabl) {
2158+
if (fsz > 0 && mutabl) {
21572159
// move value to an immutable stack slot (excluding tindex)
2158-
Type *ET = IntegerType::get(jl_LLVMContext, 8 * al);
2159-
AllocaInst *lv = emit_static_alloca(ctx, ET);
2160-
lv->setOperand(0, ConstantInt::get(T_int32, (fsz + al - 1) / al));
2160+
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (fsz + al - 1) / al);
2161+
AllocaInst *lv = emit_static_alloca(ctx, AT);
2162+
if (al > 1)
2163+
lv->setAlignment(Align(al));
21612164
emit_memcpy(ctx, lv, tbaa, addr, tbaa, fsz, al);
21622165
addr = lv;
21632166
}
2164-
return mark_julia_slot(addr, jfty, tindex, tbaa);
2167+
return mark_julia_slot(fsz > 0 ? addr : nullptr, jfty, tindex, tbaa);
21652168
}
21662169

21672170
// If `nullcheck` is not NULL and a pointer NULL check is necessary
@@ -2235,7 +2238,8 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
22352238
}
22362239
else if (jl_is_uniontype(jfty)) {
22372240
size_t fsz = 0, al = 0;
2238-
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
2241+
int union_max = jl_islayout_inline(jfty, &fsz, &al);
2242+
bool isptr = (union_max == 0);
22392243
assert(!isptr && fsz == jl_field_size(jt, idx) - 1); (void)isptr;
22402244
Value *ptindex;
22412245
if (isboxed) {
@@ -2245,7 +2249,7 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
22452249
else {
22462250
ptindex = emit_struct_gep(ctx, cast<StructType>(lt), staddr, byte_offset + fsz);
22472251
}
2248-
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl);
2252+
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl, union_max, tbaa_unionselbyte);
22492253
}
22502254
assert(jl_is_concrete_type(jfty));
22512255
if (!jt->name->mutabl && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
@@ -3306,7 +3310,8 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33063310
jl_value_t *jfty = jl_field_type(sty, idx0);
33073311
if (!jl_field_isptr(sty, idx0) && jl_is_uniontype(jfty)) {
33083312
size_t fsz = 0, al = 0;
3309-
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
3313+
int union_max = jl_islayout_inline(jfty, &fsz, &al);
3314+
bool isptr = (union_max == 0);
33103315
assert(!isptr && fsz == jl_field_size(sty, idx0) - 1); (void)isptr;
33113316
// compute tindex from rhs
33123317
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
@@ -3323,7 +3328,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33233328
}
33243329
jl_cgval_t oldval = rhs;
33253330
if (!issetfield)
3326-
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
3331+
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
33273332
Value *Success = NULL;
33283333
BasicBlock *DoneBB = NULL;
33293334
if (isreplacefield || ismodifyfield) {
@@ -3342,13 +3347,13 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33423347
emit_typecheck(ctx, rhs, jfty, fname);
33433348
rhs = update_julia_type(ctx, rhs, jfty);
33443349
}
3345-
rhs_union = convert_julia_type(ctx, rhs, jfty);
3350+
rhs_union = convert_julia_type(ctx, rhs, jfty);
33463351
if (rhs_union.typ == jl_bottom_type)
33473352
return jl_cgval_t();
33483353
if (needlock)
33493354
emit_lockstate_value(ctx, strct, true);
33503355
cmp = oldval;
3351-
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
3356+
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
33523357
}
33533358
BasicBlock *XchgBB = BasicBlock::Create(jl_LLVMContext, "xchg", ctx.f);
33543359
DoneBB = BasicBlock::Create(jl_LLVMContext, "done_xchg", ctx.f);

src/codegen.cpp

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,34 +2180,19 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
21802180
static Value *emit_box_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2,
21812181
Value *nullcheck1, Value *nullcheck2)
21822182
{
2183-
// If either sides is boxed or can be trivially boxed,
2184-
// we'll prefer to do a pointer check.
2185-
// At this point, we know that at least one of the arguments isn't a constant
2186-
// so a runtime content check will involve at least one load from the
2187-
// pointer (and likely a type check)
2188-
// so a pointer comparison should be no worse than that even in imaging mode
2189-
// when the constant pointer has to be loaded.
2190-
// Note that we ignore nullcheck, since in the case where it may be set, we
2191-
// also knew the types of both fields must be the same so there cannot be
2192-
// any unboxed values on either side.
2193-
if ((!arg1.TIndex && jl_pointer_egal(arg1.typ)) || (!arg2.TIndex && jl_pointer_egal(arg2.typ))) {
2194-
// n.b. Vboxed may be incomplete if Tindex is set (missing singletons)
2195-
// and Vboxed == isboxed || Tindex
2196-
if ((arg1.Vboxed || arg1.constant) && (arg2.Vboxed || arg2.constant)) {
2197-
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, arg1.Vboxed, T_pjlvalue);
2198-
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, arg2.Vboxed, T_pjlvalue);
2199-
return ctx.builder.CreateICmpEQ(decay_derived(ctx, varg1), decay_derived(ctx, varg2));
2200-
}
2201-
return ConstantInt::get(T_int1, 0); // seems probably unreachable?
2202-
// (since intersection of rt1 and rt2 is non-empty here, so we should have
2203-
// a value in this intersection, but perhaps intersection might have failed)
2183+
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
2184+
// if we can be certain we won't try to load from the pointer (because
2185+
// we know boxed is trivial), we can skip the separate null checks
2186+
// and just do the ICmpEQ test
2187+
if (!arg1.TIndex && !arg2.TIndex)
2188+
nullcheck1 = nullcheck2 = nullptr;
22042189
}
2205-
22062190
return emit_nullcheck_guard2(ctx, nullcheck1, nullcheck2, [&] {
2207-
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg1).V, T_pjlvalue);
2208-
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg2).V, T_pjlvalue);
2209-
varg1 = decay_derived(ctx, varg1);
2210-
varg2 = decay_derived(ctx, varg2);
2191+
Value *varg1 = decay_derived(ctx, boxed(ctx, arg1));
2192+
Value *varg2 = decay_derived(ctx, boxed(ctx, arg2));
2193+
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
2194+
return ctx.builder.CreateICmpEQ(varg1, varg2);
2195+
}
22112196
Value *neq = ctx.builder.CreateICmpNE(varg1, varg2);
22122197
return emit_guarded_test(ctx, neq, true, [&] {
22132198
Value *dtarg = emit_typeof_boxed(ctx, arg1);
@@ -2733,28 +2718,28 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
27332718
*ret = ghostValue(ety);
27342719
}
27352720
else if (!isboxed && jl_is_uniontype(ety)) {
2736-
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
2737-
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
2738-
// isbits union selector bytes are stored after a->maxsize
2739-
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2740-
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2721+
Value *data = emit_arrayptr(ctx, ary, ary_ex);
27412722
Value *offset = emit_arrayoffset(ctx, ary, nd);
2742-
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2743-
Value *selidx_m = emit_arraylen(ctx, ary);
2744-
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2745-
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2723+
Value *ptindex;
2724+
if (elsz == 0) {
2725+
ptindex = data;
2726+
}
2727+
else {
2728+
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
2729+
data = emit_bitcast(ctx, data, AT->getPointerTo());
2730+
// isbits union selector bytes are stored after a->maxsize
2731+
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2732+
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2733+
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2734+
Value *selidx_m = emit_arraylen(ctx, ary);
2735+
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2736+
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2737+
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
2738+
}
27462739
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
27472740
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
27482741
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
2749-
Instruction *tindex = tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2750-
tindex->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2751-
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2752-
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
2753-
AllocaInst *lv = emit_static_alloca(ctx, AT);
2754-
if (al > 1)
2755-
lv->setAlignment(Align(al));
2756-
emit_memcpy(ctx, lv, tbaa_arraybuf, ctx.builder.CreateInBoundsGEP(AT, data, idx), tbaa_arraybuf, elsz, al, false);
2757-
*ret = mark_julia_slot(lv, ety, ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex), tbaa_arraybuf);
2742+
*ret = emit_unionload(ctx, data, ptindex, ety, elsz, al, tbaa_arraybuf, true, union_max, tbaa_arrayselbyte);
27582743
}
27592744
else {
27602745
MDNode *aliasscope = (f == jl_builtin_const_arrayref) ? ctx.aliasscope : nullptr;
@@ -2840,28 +2825,31 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
28402825
if (!isboxed && jl_is_uniontype(ety)) {
28412826
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
28422827
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
2828+
Value *offset = emit_arrayoffset(ctx, ary, nd);
28432829
// compute tindex from val
28442830
jl_cgval_t rhs_union = convert_julia_type(ctx, val, ety);
28452831
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, ety);
28462832
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2847-
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2848-
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2849-
Value *offset = emit_arrayoffset(ctx, ary, nd);
2850-
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2851-
Value *selidx_m = emit_arraylen(ctx, ary);
2852-
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2853-
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2833+
Value *ptindex;
2834+
if (elsz == 0) {
2835+
ptindex = data;
2836+
}
2837+
else {
2838+
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2839+
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2840+
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2841+
Value *selidx_m = emit_arraylen(ctx, ary);
2842+
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2843+
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2844+
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
2845+
}
28542846
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
28552847
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
28562848
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
28572849
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateStore(tindex, ptindex));
2858-
if (jl_is_datatype(val.typ) && jl_datatype_size(val.typ) == 0) {
2859-
// no-op
2860-
}
2861-
else {
2862-
// copy data
2863-
Value *addr = ctx.builder.CreateInBoundsGEP(AT, data, idx);
2864-
emit_unionmove(ctx, addr, tbaa_arraybuf, val, nullptr);
2850+
if (elsz > 0 && (!jl_is_datatype(val.typ) || jl_datatype_size(val.typ) > 0)) {
2851+
// copy data (if any)
2852+
emit_unionmove(ctx, data, tbaa_arraybuf, val, nullptr);
28652853
}
28662854
}
28672855
else {

src/rtutils.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,12 +1014,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
10141014
n += jl_printf(out, ")}[");
10151015
size_t j, tlen = jl_array_len(v);
10161016
jl_array_t *av = (jl_array_t*)v;
1017-
jl_datatype_t *el_type = (jl_datatype_t*)jl_tparam0(vt);
1017+
jl_value_t *el_type = jl_tparam0(vt);
1018+
char *typetagdata = (!av->flags.ptrarray && jl_is_uniontype(el_type)) ? jl_array_typetagdata(av) : NULL;
10181019
int nlsep = 0;
10191020
if (av->flags.ptrarray) {
10201021
// print arrays with newlines, unless the elements are probably small
10211022
for (j = 0; j < tlen; j++) {
1022-
jl_value_t *p = jl_array_ptr_ref(av, j);
1023+
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
1024+
jl_value_t *p = *ptr;
10231025
if (p != NULL && (uintptr_t)p >= 4096U) {
10241026
jl_value_t *p_ty = jl_typeof(p);
10251027
if ((uintptr_t)p_ty >= 4096U) {
@@ -1035,11 +1037,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
10351037
n += jl_printf(out, "\n ");
10361038
for (j = 0; j < tlen; j++) {
10371039
if (av->flags.ptrarray) {
1038-
n += jl_static_show_x(out, jl_array_ptr_ref(v, j), depth);
1040+
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
1041+
n += jl_static_show_x(out, *ptr, depth);
10391042
}
10401043
else {
10411044
char *ptr = ((char*)av->data) + j * av->elsize;
1042-
n += jl_static_show_x_(out, (jl_value_t*)ptr, el_type, depth);
1045+
n += jl_static_show_x_(out, (jl_value_t*)ptr,
1046+
typetagdata ? (jl_datatype_t*)jl_nth_union_component(el_type, typetagdata[j]) : (jl_datatype_t*)el_type,
1047+
depth);
10431048
}
10441049
if (j != tlen - 1)
10451050
n += jl_printf(out, nlsep ? ",\n " : ", ");

test/compiler/codegen.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,3 +672,13 @@ function f42645()
672672
res
673673
end
674674
@test ((f42645()::B42645).y::A42645{Int}).x
675+
676+
# issue #43123
677+
@noinline cmp43123(a::Some, b::Some) = something(a) === something(b)
678+
@noinline cmp43123(a, b) = a[] === b[]
679+
@test cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(+))
680+
@test !cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(-))
681+
@test cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(+))
682+
@test !cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(-))
683+
@test cmp43123(Function[+], Union{typeof(+), typeof(-)}[+])
684+
@test !cmp43123(Function[+], Union{typeof(+), typeof(-)}[-])

0 commit comments

Comments
 (0)