Skip to content

Commit 4912d77

Browse files
authored
[mypyc] Add inline primitives for bytes.__getitem__ (#20552)
These are similar to the ones we have for `BytesWriter`. These are a bit faster that the old primitive in typical cases (based on a microbenchmark), but in an ideal scenario these allow vectorization by the C compiler and something like 10x performance improvement. I don't remove the old primitive yet, since it's used in some places which can't currently use the specializer. I'll look at running the specializer in more contexts as a follow-up PR. This PR should cover some of the most common use cases. Also refactor to share some code with `BytesWriter` specialization.
1 parent 430924a commit 4912d77

File tree

4 files changed

+143
-19
lines changed

4 files changed

+143
-19
lines changed

mypyc/irbuild/specialize.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@
9595
tokenizer_format_call,
9696
)
9797
from mypyc.primitives.bytearray_ops import isinstance_bytearray
98-
from mypyc.primitives.bytes_ops import isinstance_bytes
98+
from mypyc.primitives.bytes_ops import (
99+
bytes_adjust_index_op,
100+
bytes_get_item_unsafe_op,
101+
bytes_range_check_op,
102+
isinstance_bytes,
103+
)
99104
from mypyc.primitives.dict_ops import (
100105
dict_items_op,
101106
dict_keys_op,
@@ -1207,30 +1212,50 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr
12071212
return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)
12081213

12091214

1210-
@specialize_dunder("__getitem__", bytes_writer_rprimitive)
1211-
def translate_bytes_writer_get_item(
1212-
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1215+
def translate_getitem_with_bounds_check(
1216+
builder: IRBuilder,
1217+
base_expr: Expression,
1218+
args: list[Expression],
1219+
ctx_expr: Expression,
1220+
adjust_index_op: PrimitiveDescription,
1221+
range_check_op: PrimitiveDescription,
1222+
get_item_unsafe_op: PrimitiveDescription,
12131223
) -> Value | None:
1214-
"""Optimized BytesWriter.__getitem__ implementation with bounds checking."""
1224+
"""Shared helper for optimized __getitem__ with bounds checking.
1225+
1226+
This implements the common pattern of:
1227+
1. Adjusting negative indices
1228+
2. Checking if index is in valid range
1229+
3. Raising IndexError if out of range
1230+
4. Getting the item if in range
1231+
1232+
Args:
1233+
builder: The IR builder
1234+
base_expr: The base object expression
1235+
args: The arguments to __getitem__ (should be length 1)
1236+
ctx_expr: The context expression for line numbers
1237+
adjust_index_op: Primitive op to adjust negative indices
1238+
range_check_op: Primitive op to check if index is in valid range
1239+
get_item_unsafe_op: Primitive op to get item (no bounds checking)
1240+
1241+
Returns:
1242+
The result value, or None if optimization doesn't apply
1243+
"""
12151244
# Check that we have exactly one argument
12161245
if len(args) != 1:
12171246
return None
12181247

1219-
# Get the BytesWriter object
1248+
# Get the object
12201249
obj = builder.accept(base_expr)
12211250

12221251
# Get the index argument
12231252
index = builder.accept(args[0])
12241253

12251254
# Adjust the index (handle negative indices)
1226-
adjusted_index = builder.primitive_op(
1227-
bytes_writer_adjust_index_op, [obj, index], ctx_expr.line
1228-
)
1255+
adjusted_index = builder.primitive_op(adjust_index_op, [obj, index], ctx_expr.line)
12291256

12301257
# Check if the adjusted index is in valid range
1231-
range_check = builder.primitive_op(
1232-
bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line
1233-
)
1258+
range_check = builder.primitive_op(range_check_op, [obj, adjusted_index], ctx_expr.line)
12341259

12351260
# Create blocks for branching
12361261
valid_block = BasicBlock()
@@ -1247,13 +1272,27 @@ def translate_bytes_writer_get_item(
12471272

12481273
# Handle valid index - get the item
12491274
builder.activate_block(valid_block)
1250-
result = builder.primitive_op(
1251-
bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
1252-
)
1275+
result = builder.primitive_op(get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line)
12531276

12541277
return result
12551278

12561279

1280+
@specialize_dunder("__getitem__", bytes_writer_rprimitive)
1281+
def translate_bytes_writer_get_item(
1282+
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1283+
) -> Value | None:
1284+
"""Optimized BytesWriter.__getitem__ implementation with bounds checking."""
1285+
return translate_getitem_with_bounds_check(
1286+
builder,
1287+
base_expr,
1288+
args,
1289+
ctx_expr,
1290+
bytes_writer_adjust_index_op,
1291+
bytes_writer_range_check_op,
1292+
bytes_writer_get_item_unsafe_op,
1293+
)
1294+
1295+
12571296
@specialize_dunder("__setitem__", bytes_writer_rprimitive)
12581297
def translate_bytes_writer_set_item(
12591298
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
@@ -1300,3 +1339,19 @@ def translate_bytes_writer_set_item(
13001339
)
13011340

13021341
return builder.none()
1342+
1343+
1344+
@specialize_dunder("__getitem__", bytes_rprimitive)
1345+
def translate_bytes_get_item(
1346+
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1347+
) -> Value | None:
1348+
"""Optimized bytes.__getitem__ implementation with bounds checking."""
1349+
return translate_getitem_with_bounds_check(
1350+
builder,
1351+
base_expr,
1352+
args,
1353+
ctx_expr,
1354+
bytes_adjust_index_op,
1355+
bytes_range_check_op,
1356+
bytes_get_item_unsafe_op,
1357+
)

mypyc/lib-rt/bytes_extra_ops.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,30 @@
22
#define MYPYC_BYTES_EXTRA_OPS_H
33

44
#include <Python.h>
5+
#include <stdint.h>
56
#include "CPy.h"
67

78
// Optimized bytes translate operation
89
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);
910

11+
// Optimized bytes.__getitem__ operations
12+
13+
// If index is negative, convert to non-negative index (no range checking)
14+
static inline int64_t CPyBytes_AdjustIndex(PyObject *obj, int64_t index) {
15+
if (index < 0) {
16+
return index + Py_SIZE(obj);
17+
}
18+
return index;
19+
}
20+
21+
// Check if index is in valid range [0, len)
22+
static inline bool CPyBytes_RangeCheck(PyObject *obj, int64_t index) {
23+
return index >= 0 && index < Py_SIZE(obj);
24+
}
25+
26+
// Get byte at index (no bounds checking) - returns as CPyTagged
27+
static inline CPyTagged CPyBytes_GetItemUnsafe(PyObject *obj, int64_t index) {
28+
return ((CPyTagged)(uint8_t)(PyBytes_AS_STRING(obj))[index]) << 1;
29+
}
30+
1031
#endif

mypyc/primitives/bytes_ops.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
c_int_rprimitive,
1313
c_pyssize_t_rprimitive,
1414
dict_rprimitive,
15+
int64_rprimitive,
1516
int_rprimitive,
1617
list_rprimitive,
1718
object_rprimitive,
@@ -21,6 +22,7 @@
2122
ERR_NEG_INT,
2223
binary_op,
2324
custom_op,
25+
custom_primitive_op,
2426
function_op,
2527
load_address_op,
2628
method_op,
@@ -148,3 +150,38 @@
148150
c_function_name="CPyBytes_Ord",
149151
error_kind=ERR_MAGIC,
150152
)
153+
154+
# Optimized bytes.__getitem__ operations
155+
156+
# bytes index adjustment - convert negative index to positive
157+
bytes_adjust_index_op = custom_primitive_op(
158+
name="bytes_adjust_index",
159+
arg_types=[bytes_rprimitive, int64_rprimitive],
160+
return_type=int64_rprimitive,
161+
c_function_name="CPyBytes_AdjustIndex",
162+
error_kind=ERR_NEVER,
163+
experimental=True,
164+
dependencies=[BYTES_EXTRA_OPS],
165+
)
166+
167+
# bytes range check - check if index is in valid range
168+
bytes_range_check_op = custom_primitive_op(
169+
name="bytes_range_check",
170+
arg_types=[bytes_rprimitive, int64_rprimitive],
171+
return_type=bool_rprimitive,
172+
c_function_name="CPyBytes_RangeCheck",
173+
error_kind=ERR_NEVER,
174+
experimental=True,
175+
dependencies=[BYTES_EXTRA_OPS],
176+
)
177+
178+
# bytes.__getitem__() - get byte at index (no bounds checking)
179+
bytes_get_item_unsafe_op = custom_primitive_op(
180+
name="bytes_get_item_unsafe",
181+
arg_types=[bytes_rprimitive, int64_rprimitive],
182+
return_type=int_rprimitive,
183+
c_function_name="CPyBytes_GetItemUnsafe",
184+
error_kind=ERR_NEVER,
185+
experimental=True,
186+
dependencies=[BYTES_EXTRA_OPS],
187+
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,26 @@ L0:
100100
return r0
101101

102102
[case testBytesIndex]
103-
def f(a: bytes, i: int) -> int:
103+
from mypy_extensions import i64
104+
105+
def f(a: bytes, i: i64) -> int:
104106
return a[i]
105107
[out]
106108
def f(a, i):
107109
a :: bytes
108-
i, r0 :: int
110+
i, r0 :: i64
111+
r1, r2 :: bool
112+
r3 :: int
109113
L0:
110-
r0 = CPyBytes_GetItem(a, i)
111-
return r0
114+
r0 = CPyBytes_AdjustIndex(a, i)
115+
r1 = CPyBytes_RangeCheck(a, r0)
116+
if r1 goto L2 else goto L1 :: bool
117+
L1:
118+
r2 = raise IndexError('index out of range')
119+
unreachable
120+
L2:
121+
r3 = CPyBytes_GetItemUnsafe(a, r0)
122+
return r3
112123

113124
[case testBytesConcat]
114125
def f(a: bytes, b: bytes) -> bytes:

0 commit comments

Comments
 (0)