Skip to content

Commit 5177cee

Browse files
committed
fix: replace saturating add and sub with llvm intrinsics
Signed-off-by: Seth Stadick <[email protected]> fix: add tests for saturating sub and add Signed-off-by: Seth Stadick <[email protected]>
1 parent 9f2a943 commit 5177cee

File tree

2 files changed

+87
-11
lines changed

2 files changed

+87
-11
lines changed

ishlib/matcher/alignment/striped_utils.mojo

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from memory import memset_zero
2+
from sys import llvm_intrinsic
3+
24
from ishlib.matcher.alignment import AlignedMemory
35

46

@@ -8,26 +10,34 @@ fn saturating_sub[
810
](lhs: SIMD[data, width], rhs: SIMD[data, width]) -> SIMD[data, width]:
911
"""Saturating SIMD subtraction.
1012
11-
https://stackoverflow.com/questions/33481295/saturating-subtract-add-for-unsigned-bytes
13+
https://llvm.org/docs/LangRef.html#llvm-usub-sat-intrinsics
14+
https://llvm.org/docs/LangRef.html#llvm-ssub-sat-intrinsics
1215
"""
13-
constrained[data.is_unsigned()]()
14-
var resp = lhs - rhs
15-
resp &= -(resp <= lhs).cast[data]()
16-
return resp
16+
constrained[data.is_integral()]()
17+
18+
@parameter
19+
if data.is_unsigned():
20+
return llvm_intrinsic["llvm.usub.sat", __type_of(lhs)](lhs, rhs)
21+
else:
22+
return llvm_intrinsic["llvm.ssub.sat", __type_of(lhs)](lhs, rhs)
1723

1824

1925
@always_inline
2026
fn saturating_add[
2127
data: DType, width: Int
2228
](lhs: SIMD[data, width], rhs: SIMD[data, width]) -> SIMD[data, width]:
23-
"""Saturating SIMD subtraction.
29+
"""Saturating SIMD addition.
2430
25-
https://stackoverflow.com/questions/33481295/saturating-subtract-add-for-unsigned-bytes
31+
https://llvm.org/docs/LangRef.html#llvm-uadd-sat-intrinsics
32+
https://llvm.org/docs/LangRef.html#llvm-sadd-sat-intrinsics
2633
"""
27-
constrained[data.is_unsigned()]()
28-
var resp = lhs + rhs
29-
resp |= -(resp < lhs).cast[data]()
30-
return resp
34+
constrained[data.is_integral()]()
35+
36+
@parameter
37+
if data.is_unsigned():
38+
return llvm_intrinsic["llvm.uadd.sat", __type_of(lhs)](lhs, rhs)
39+
else:
40+
return llvm_intrinsic["llvm.sadd.sat", __type_of(lhs)](lhs, rhs)
3141

3242

3343
@value
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from testing import assert_equal
2+
3+
from ishlib.matcher.alignment.striped_utils import (
4+
saturating_add,
5+
saturating_sub,
6+
)
7+
8+
9+
def test_saturating_add():
10+
alias dtypes = List[DType](
11+
DType.uint8,
12+
DType.int8,
13+
DType.uint16,
14+
DType.int16,
15+
DType.uint32,
16+
DType.int32,
17+
)
18+
alias widths = List[Int](4, 8, 16, 32, 64, 128)
19+
20+
@parameter
21+
for i in range(0, len(dtypes)):
22+
23+
@parameter
24+
for j in range(0, len(widths)):
25+
alias dtype = dtypes[i]
26+
alias width = widths[j]
27+
alias MIN = Scalar[dtype].MIN
28+
alias MAX = Scalar[dtype].MAX
29+
30+
var lhs = SIMD[dtype, width](MAX)
31+
var rhs = SIMD[dtype, width](1)
32+
var expected = SIMD[dtype, width](MAX)
33+
assert_equal(saturating_add(lhs, rhs), expected)
34+
35+
expected = SIMD[dtype, width](2)
36+
assert_equal(saturating_add(rhs, rhs), expected)
37+
38+
39+
def test_saturating_sub():
40+
alias dtypes = List[DType](
41+
DType.uint8,
42+
DType.int8,
43+
DType.uint16,
44+
DType.int16,
45+
DType.uint32,
46+
DType.int32,
47+
)
48+
alias widths = List[Int](4, 8, 16, 32, 64, 128)
49+
50+
@parameter
51+
for i in range(0, len(dtypes)):
52+
53+
@parameter
54+
for j in range(0, len(widths)):
55+
alias dtype = dtypes[i]
56+
alias width = widths[j]
57+
alias MIN = Scalar[dtype].MIN
58+
alias MAX = Scalar[dtype].MAX
59+
60+
var lhs = SIMD[dtype, width](MIN)
61+
var rhs = SIMD[dtype, width](1)
62+
var expected = SIMD[dtype, width](MIN)
63+
assert_equal(saturating_sub(lhs, rhs), expected)
64+
65+
expected = SIMD[dtype, width](0)
66+
assert_equal(saturating_sub(rhs, rhs), expected)

0 commit comments

Comments
 (0)