Skip to content

Commit 48b9353

Browse files
junikimm717willow-ahrensmtsokol
authored
Testing for Safe Buffers (#171)
* reduce test cases for codegen * regression testing * ruff format * formatting stuff * fix * Update test output * Support windows platform reference files --------- Co-authored-by: Willow Ahrens <[email protected]> Co-authored-by: Mateusz Sokół <[email protected]> Co-authored-by: Mateusz Sokół <[email protected]>
1 parent 9ae1800 commit 48b9353

8 files changed

+193
-87
lines changed

src/finchlite/codegen/c.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def construct_from_c(fmt, c_obj):
202202
)
203203
# construction from c is just the identity.
204204
register_property(t, "construct_from_c", "__attr__", lambda fmt, c_value: c_value)
205+
register_property(t, "numba_type", "__attr__", lambda t: t)
205206

206207
register_property(
207208
np.generic,
@@ -529,24 +530,25 @@ def c_type(t):
529530
ctypes.c_wchar: ("wchar_t", ["wchar.h"]),
530531
ctypes.c_byte: ("char", []),
531532
ctypes.c_ubyte: ("unsigned char", []),
532-
ctypes.c_short: ("short", []),
533-
ctypes.c_ushort: ("unsigned short", []),
534-
ctypes.c_int: ("int", []),
535533
ctypes.c_int8: ("int8_t", ["stdint.h"]),
536534
ctypes.c_int16: ("int16_t", ["stdint.h"]),
537535
ctypes.c_int32: ("int32_t", ["stdint.h"]),
538536
ctypes.c_int64: ("int64_t", ["stdint.h"]),
539-
ctypes.c_uint: ("unsigned int", []),
540537
ctypes.c_uint8: ("uint8_t", ["stdint.h"]),
541538
ctypes.c_uint16: ("uint16_t", ["stdint.h"]),
542539
ctypes.c_uint32: ("uint32_t", ["stdint.h"]),
543540
ctypes.c_uint64: ("uint64_t", ["stdint.h"]),
544-
ctypes.c_long: ("long", []),
545-
ctypes.c_ulong: ("unsigned long", []),
546-
ctypes.c_longlong: ("long long", []),
547-
ctypes.c_ulonglong: ("unsigned long long", []),
548-
ctypes.c_size_t: ("size_t", ["stddef.h"]),
549-
ctypes.c_ssize_t: ("ssize_t", ["unistd.h"]),
541+
# use standard types instead of aliases
542+
# ctypes.c_short: ("short", []),
543+
# ctypes.c_ushort: ("unsigned short", []),
544+
# ctypes.c_int: ("int", []),
545+
# ctypes.c_uint: ("unsigned int", []),
546+
# ctypes.c_long: ("long", []),
547+
# ctypes.c_ulong: ("unsigned long", []),
548+
# ctypes.c_longlong: ("long long", []),
549+
# ctypes.c_ulonglong: ("unsigned long long", []),
550+
# ctypes.c_size_t: ("size_t", ["stddef.h"]),
551+
# ctypes.c_ssize_t: ("ssize_t", ["unistd.h"]),
550552
ctypes.c_float: ("float", []),
551553
ctypes.c_double: ("double", []),
552554
ctypes.c_char_p: ("char*", []),

src/finchlite/codegen/numpy_buffer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def c_unpack(self, ctx, var_n, val):
154154
data = ctx.freshen(var_n, "data")
155155
length = ctx.freshen(var_n, "length")
156156
t = ctx.ctype_name(c_type(self._dtype))
157+
ctx.add_header("#include <stddef.h>")
157158
ctx.exec(
158159
f"{ctx.feed}{t}* {data} = ({t}*){ctx(val)}->data;\n"
159160
f"{ctx.feed}size_t {length} = {ctx(val)}->length;"

src/finchlite/codegen/safe_buffer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ def c_data(self, *args, **kwargs):
6565
def _c_check(self, ctx, buf, idx):
6666
ctx.add_header("#include <stdio.h>")
6767
ctx.add_header("#include <stdlib.h>")
68+
ctx.add_header("#include <stddef.h>")
6869
idx_n = ctx.freshen("computed")
6970
ctx.exec(
7071
f"{ctx.feed}size_t {idx_n} = ({ctx(idx)});\n"
71-
f"{ctx.feed}if ({idx_n} < 0 || {idx_n} >= ({self.c_length(ctx, buf)}))"
72-
"{"
73-
f'fprintf(stderr, "Encountered an index out of bounds error!");\n'
74-
f"exit(1);\n"
75-
"}"
72+
f"{ctx.feed}if ({idx_n} < 0 || {idx_n} >= ({self.c_length(ctx, buf)})) {{\n"
73+
f'{ctx.feed} fprintf(stderr, "Index out of bounds error!");\n'
74+
f"{ctx.feed} exit(1);\n"
75+
f"{ctx.feed}}}"
7676
)
7777
return asm.Variable(idx_n, ctypes.c_size_t)
7878

tests/reference/test_dot_product_regression_compiler0__c_NumpyBuffer_.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#include <stddef.h>
2-
typedef void* (*fptr)( void**, size_t );
1+
#include <stdint.h>
2+
typedef void* (*fptr)( void**, uint64_t );
33
struct CNumpyBuffer {
44
void* arr;
55
void* data;
6-
size_t length;
6+
uint64_t length;
77
fptr resize;
88
};
9-
#include <unistd.h>
9+
#include <stddef.h>
1010
double dot_product(struct CNumpyBuffer* a, struct CNumpyBuffer* b) {
1111
double c = (double)0.0;
1212
struct CNumpyBuffer* a_ = a;
@@ -15,7 +15,7 @@ double dot_product(struct CNumpyBuffer* a, struct CNumpyBuffer* b) {
1515
struct CNumpyBuffer* b_ = b;
1616
double* b__data = (double*)b_->data;
1717
size_t b__length = b_->length;
18-
for (ssize_t i = (ssize_t)0; i < a__length; i++) {
18+
for (int64_t i = (int64_t)0; i < a__length; i++) {
1919
c = c + (a__data)[i] * (b__data)[i];
2020
}
2121
a_->data = (void*)a__data;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <stdint.h>
2+
typedef void* (*fptr)( void**, uint64_t );
3+
struct CNumpyBuffer {
4+
void* arr;
5+
void* data;
6+
uint64_t length;
7+
fptr resize;
8+
};
9+
#include <stddef.h>
10+
#include <stdio.h>
11+
#include <stdlib.h>
12+
int64_t finch_access(struct CNumpyBuffer* a, uint64_t idx) {
13+
struct CNumpyBuffer* a_ = a;
14+
int64_t* a__data = (int64_t*)a_->data;
15+
size_t a__length = a_->length;
16+
size_t computed = (idx);
17+
if (computed < 0 || computed >= (a__length)) {
18+
fprintf(stderr, "Index out of bounds error!");
19+
exit(1);
20+
}
21+
int64_t val = (a__data)[computed];
22+
size_t computed_2 = (idx);
23+
if (computed_2 < 0 || computed_2 >= (a__length)) {
24+
fprintf(stderr, "Index out of bounds error!");
25+
exit(1);
26+
}
27+
int64_t val2 = (a__data)[computed_2];
28+
return val;
29+
}
30+
int64_t finch_change(struct CNumpyBuffer* a, uint64_t idx, int64_t val) {
31+
struct CNumpyBuffer* a_ = a;
32+
int64_t* a__data_2 = (int64_t*)a_->data;
33+
size_t a__length_2 = a_->length;
34+
size_t computed_3 = (idx);
35+
if (computed_3 < 0 || computed_3 >= (a__length_2)) {
36+
fprintf(stderr, "Index out of bounds error!");
37+
exit(1);
38+
}
39+
(a__data_2)[computed_3] = val;
40+
return (int64_t)0;
41+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import _operator, builtins
2+
from numba import njit
3+
import numpy
4+
from numpy import int64, float64
5+
6+
7+
@njit
8+
def finch_access(a: builtins.list, idx: ctypes.c_ulong) -> int64:
9+
a_ = a
10+
a__arr = a_[0]
11+
computed = (idx)
12+
if computed < 0 or computed >= (len(a__arr)):
13+
raise IndexError()
14+
val: int64 = a__arr[computed]
15+
computed_2 = (idx)
16+
if computed_2 < 0 or computed_2 >= (len(a__arr)):
17+
raise IndexError()
18+
val2: int64 = a__arr[computed_2]
19+
return val
20+
21+
@njit
22+
def finch_change(a: builtins.list, idx: ctypes.c_ulong, val: ctypes.c_long) -> int64:
23+
a_ = a
24+
a__arr_2 = a_[0]
25+
computed_3 = (idx)
26+
if computed_3 < 0 or computed_3 >= (len(a__arr_2)):
27+
raise IndexError()
28+
a__arr_2[computed_3] = val
29+
return c_long(0)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import _operator, builtins
2+
from numba import njit
3+
import numpy
4+
from numpy import int64, float64
5+
6+
7+
@njit
8+
def finch_access(a: builtins.list, idx: ctypes.c_ulonglong) -> int64:
9+
a_ = a
10+
a__arr = a_[0]
11+
computed = (idx)
12+
if computed < 0 or computed >= (len(a__arr)):
13+
raise IndexError()
14+
val: int64 = a__arr[computed]
15+
computed_2 = (idx)
16+
if computed_2 < 0 or computed_2 >= (len(a__arr)):
17+
raise IndexError()
18+
val2: int64 = a__arr[computed_2]
19+
return val
20+
21+
@njit
22+
def finch_change(a: builtins.list, idx: ctypes.c_ulonglong, val: ctypes.c_longlong) -> int64:
23+
a_ = a
24+
a__arr_2 = a_[0]
25+
computed_3 = (idx)
26+
if computed_3 < 0 or computed_3 >= (len(a__arr_2)):
27+
raise IndexError()
28+
a__arr_2[computed_3] = val
29+
return c_longlong(0)

tests/test_codegen.py

Lines changed: 71 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,67 @@ def test_simple_struct(compiler):
441441
assert result == np.float64(9.0)
442442

443443

444+
@pytest.mark.parametrize(
445+
["compiler", "extension", "platform"],
446+
[
447+
(CGenerator(), ".c", "any"),
448+
(NumbaGenerator(), ".py", "win" if sys.platform == "win32" else "any"),
449+
],
450+
)
451+
def test_safe_loadstore_regression(compiler, extension, platform, file_regression):
452+
a = np.array(range(3), dtype=ctypes.c_int64)
453+
ab = NumpyBuffer(a)
454+
ab_safe = SafeBuffer(ab)
455+
ab_v = asm.Variable("a", ab_safe.ftype)
456+
ab_slt = asm.Slot("a_", ab_safe.ftype)
457+
idx = asm.Variable("idx", ctypes.c_size_t)
458+
val = asm.Variable("val", ctypes.c_int64)
459+
460+
res_var = asm.Variable("val", ab_safe.ftype.element_type)
461+
res_var2 = asm.Variable("val2", ab_safe.ftype.element_type)
462+
mod = asm.Module(
463+
(
464+
asm.Function(
465+
asm.Variable("finch_access", ab_safe.ftype.element_type),
466+
(ab_v, idx),
467+
asm.Block(
468+
(
469+
asm.Unpack(ab_slt, ab_v),
470+
# we assign twice like this; this is intentional and
471+
# designed to check correct refreshing.
472+
asm.Assign(
473+
res_var,
474+
asm.Load(ab_slt, idx),
475+
),
476+
asm.Assign(
477+
res_var2,
478+
asm.Load(ab_slt, idx),
479+
),
480+
asm.Return(res_var),
481+
)
482+
),
483+
),
484+
asm.Function(
485+
asm.Variable("finch_change", ab_safe.ftype.element_type),
486+
(ab_v, idx, val),
487+
asm.Block(
488+
(
489+
asm.Unpack(ab_slt, ab_v),
490+
asm.Store(
491+
ab_slt,
492+
idx,
493+
val,
494+
),
495+
asm.Return(asm.Literal(ctypes.c_int64(0))),
496+
)
497+
),
498+
),
499+
)
500+
)
501+
output = compiler(mod)
502+
file_regression.check(output, extension=extension)
503+
504+
444505
@pytest.mark.parametrize(
445506
"size,idx",
446507
[(size, idx) for size in range(1, 4) for idx in range(-1, 4)],
@@ -472,34 +533,10 @@ def test_c_load_safebuffer(size, idx):
472533
[
473534
(*params, compiler)
474535
for params in [
475-
(
476-
-1,
477-
2,
478-
),
479-
(
480-
-1,
481-
3,
482-
),
483-
(
484-
0,
485-
2,
486-
),
487-
(
488-
1,
489-
2,
490-
),
491-
(
492-
2,
493-
3,
494-
),
495-
(
496-
2,
497-
2,
498-
),
499-
(
500-
3,
501-
2,
502-
),
536+
(-1, 2),
537+
(1, 2),
538+
(2, 3),
539+
(2, 2),
503540
]
504541
for compiler in [asm.AssemblyInterpreter(), NumbaCompiler()]
505542
],
@@ -548,12 +585,9 @@ def test_numba_load_safebuffer(size, idx, compiler):
548585
(*params, compiler)
549586
for params in [
550587
(-1, 2, 3),
551-
(-1, 3, 1434),
552-
(0, 2, 3),
553-
(1, 2, 3),
554-
(2, 3, 3),
588+
(1, 2, 1434),
589+
(2, 3, 1434),
555590
(2, 2, 3),
556-
(3, 2, 3),
557591
]
558592
for compiler in [NumbaCompiler(), asm.AssemblyInterpreter()]
559593
],
@@ -599,34 +633,10 @@ def test_numba_store_safebuffer(size, idx, value, compiler):
599633
[
600634
(*params, value)
601635
for params in [
602-
(
603-
-1,
604-
2,
605-
),
606-
(
607-
-1,
608-
3,
609-
),
610-
(
611-
0,
612-
2,
613-
),
614-
(
615-
1,
616-
2,
617-
),
618-
(
619-
2,
620-
3,
621-
),
622-
(
623-
2,
624-
2,
625-
),
626-
(
627-
3,
628-
2,
629-
),
636+
(-1, 2),
637+
(1, 2),
638+
(2, 3),
639+
(2, 2),
630640
]
631641
for value in [-1, 1434]
632642
],
@@ -663,9 +673,7 @@ def test_c_store_safebuffer(size, idx, value):
663673
"value,np_type,c_type",
664674
[
665675
(3, np.int64, ctypes.c_int64),
666-
(2, np.int32, ctypes.c_int32),
667676
(1, np.float32, ctypes.c_float),
668-
(1.0, np.float64, ctypes.c_double),
669677
(1.2, np.float64, ctypes.c_double),
670678
],
671679
)
@@ -682,9 +690,7 @@ def test_np_c_serialization(value, np_type, c_type):
682690
"value,c_type",
683691
[
684692
(3, ctypes.c_int64),
685-
(2, ctypes.c_int32),
686693
(1, ctypes.c_float),
687-
(1.0, ctypes.c_double),
688694
(1.2, ctypes.c_double),
689695
],
690696
)
@@ -702,9 +708,7 @@ def test_ctypes_c_serialization(value, c_type):
702708
"value,np_type",
703709
[
704710
(3, np.int64),
705-
(2, np.int32),
706711
(1, np.float32),
707-
(1.0, np.float64),
708712
(1.2, np.float64),
709713
],
710714
)

0 commit comments

Comments
 (0)