Skip to content

Commit 2b90e9b

Browse files
committed
fix ret type
1 parent 5fb8788 commit 2b90e9b

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

sql/expression/bit_ops.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,8 @@ func (b *BitOp) Type() sql.Type {
9191
return lTyp
9292
}
9393

94-
if types.IsText(lTyp) || types.IsText(rTyp) {
95-
return types.Float64
96-
}
97-
98-
if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
99-
return types.Uint64
100-
} else if types.IsSigned(lTyp) && types.IsSigned(rTyp) {
101-
return types.Int64
102-
}
103-
104-
return types.Float64
94+
// MySQL bitwise operations always return unsigned results, even for signed operands.
95+
return types.Uint64
10596
}
10697

10798
// CollationCoercibility implements the interface sql.CollationCoercible.
@@ -168,7 +159,20 @@ func (b *BitOp) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inter
168159
}
169160

170161
func (b *BitOp) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) {
171-
typ := b.Type()
162+
// Determine the appropriate conversion type based on operand types
163+
var typ sql.Type
164+
lTyp := b.LeftChild.Type()
165+
rTyp := b.RightChild.Type()
166+
167+
if types.IsText(lTyp) || types.IsText(rTyp) {
168+
typ = types.Float64
169+
} else if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
170+
typ = types.Uint64
171+
} else if types.IsSigned(lTyp) && types.IsSigned(rTyp) {
172+
typ = types.Int64
173+
} else {
174+
typ = types.Float64
175+
}
172176

173177
left = convertValueToType(ctx, typ, left, types.IsTime(b.LeftChild.Type()))
174178
right = convertValueToType(ctx, typ, right, types.IsTime(b.RightChild.Type()))

sql/expression/bit_ops_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,44 @@ func TestAllUint64(t *testing.T) {
196196
})
197197
}
198198
}
199+
200+
func TestBitOpType(t *testing.T) {
201+
testCases := []struct {
202+
name string
203+
leftType sql.Type
204+
rightType sql.Type
205+
expectedType sql.Type
206+
description string
207+
}{
208+
{"unsigned & unsigned", types.Uint64, types.Uint64, types.Uint64, "Both operands are unsigned"},
209+
{"signed & signed", types.Int64, types.Int64, types.Uint64, "Both operands are signed - should return Uint64"},
210+
{"mixed signed & unsigned", types.Int64, types.Uint64, types.Uint64, "Mixed signed/unsigned operands"},
211+
{"text & text", types.Text, types.Text, types.Uint64, "Text operands should return Float64"},
212+
{"text & int", types.Text, types.Int64, types.Uint64, "Mixed text and numeric operands"},
213+
{"float & float", types.Float64, types.Float64, types.Uint64, "Float operands should return Uint64"},
214+
}
215+
216+
operations := []struct {
217+
name string
218+
op func(left, right sql.Expression) *BitOp
219+
}{
220+
{"BitAnd", NewBitAnd},
221+
{"BitOr", NewBitOr},
222+
{"BitXor", NewBitXor},
223+
{"ShiftLeft", NewShiftLeft},
224+
{"ShiftRight", NewShiftRight},
225+
}
226+
227+
for _, tt := range testCases {
228+
for _, op := range operations {
229+
t.Run(tt.name+"_"+op.name, func(t *testing.T) {
230+
require := require.New(t)
231+
bitOp := op.op(NewLiteral(1, tt.leftType), NewLiteral(1, tt.rightType))
232+
actualType := bitOp.Type()
233+
require.Equal(tt.expectedType, actualType,
234+
"BitOp.Type() should return %v for %s: %s",
235+
tt.expectedType, tt.name, tt.description)
236+
})
237+
}
238+
}
239+
}

0 commit comments

Comments
 (0)