From 1c99ff4c74eecd36c6c954dd113928b28a783813 Mon Sep 17 00:00:00 2001 From: Mark Reed Date: Sun, 23 Feb 2025 05:28:15 -0800 Subject: [PATCH 1/7] Add: GoLang support (#153) Co-authored-by: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> --- CONTRIBUTING.md | 20 +++++++ go/stringzilla/main.go | 127 +++++++++++++++++++++++++++++++++++++++++ scripts/bench.go | 70 +++++++++++++++++++++++ scripts/test.go | 59 +++++++++++++++++++ 4 files changed, 276 insertions(+) create mode 100644 go/stringzilla/main.go create mode 100644 scripts/bench.go create mode 100644 scripts/test.go diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524d6c49..e6ed8331 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -471,6 +471,26 @@ cargo package --list --allow-dirty If you want to run benchmarks against third-party implementations, check out the [`ashvardanian/memchr_vs_stringzilla`](https://github.com/ashvardanian/memchr_vs_stringzilla/) repository. +## Contributing in Go + +```bash +export GO111MODULE="off" +go run scripts/test.go +go run scripts/bench.go +``` + +To run locally import with a relative path + +```bash + sz "../StringZilla/go/stringzilla" +``` + +And turn off GO111MODULE + +```bash +export GO111MODULE="off" +``` + ## General Recommendations ### Operations Not Worth Optimizing diff --git a/go/stringzilla/main.go b/go/stringzilla/main.go new file mode 100644 index 00000000..b9b566c3 --- /dev/null +++ b/go/stringzilla/main.go @@ -0,0 +1,127 @@ +package sz + +// #cgo CFLAGS: -g -mavx2 +// #include +// #include <../../include/stringzilla/stringzilla.h> +import "C" + +// -Wall -O3 + +import ( + "unsafe" +) + +/* +// Passing a C function pointer around in go isn't working +//type searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t +//func _search( str string, pat string, searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t) uintptr { +func _search( str string, pat string, searchFunc C.sz_find_t ) uintptr { + cstr := 
(*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer( searchFunc(cstr, C.ulong(strlen), cpat, C.ulong(patlen)) ) + return ret +} +*/ + +func Contains(str string, pat string) bool { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + //ret := _search( str, pat, C.sz_find_t(C.sz_find) ) + return ret != nil +} + +func Index(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return 0 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} + +func Find(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} + +func LastIndex(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_rfind(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} +func RFind(str string, pat string) int64 { + return LastIndex(str, pat) +} + +func IndexAny(str string, charset string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := 
(*C.char)(unsafe.Pointer(unsafe.StringData(charset))) + strlen := len(str) + patlen := len(charset) + ret := unsafe.Pointer(C.sz_find_char_from(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} +func FindCharFrom(str string, charset string) int64 { + return IndexAny(str, charset) +} + +func Count(str string, pat string, overlap bool) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := int64(len(str)) + patlen := int64(len(pat)) + + if strlen == 0 || patlen == 0 || strlen < patlen { + return 0 + } + + count := int64(0) + if overlap == true { + for strlen > 0 { + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + break + } + count += 1 + strlen -= (1 + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) + cstr = (*C.char)(unsafe.Add(ret, 1)) + } + } else { + for strlen > 0 { + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + break + } + count += 1 + strlen -= (patlen + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) + cstr = (*C.char)(unsafe.Add(ret, patlen)) + } + } + + return count + +} diff --git a/scripts/bench.go b/scripts/bench.go new file mode 100644 index 00000000..a63380ea --- /dev/null +++ b/scripts/bench.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "strings" + "time" + + sz "../go/stringzilla" +) + +func main() { + + str := strings.Repeat("0123456789", 10000) + "something" + pat := "some" + + fmt.Println("Contains") + t := time.Now() + for i := 0; i < 1; i++ { + strings.Contains(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.Contains") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Contains(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.Contains") + + fmt.Println("Index") + t = time.Now() + for i := 0; i < 1; i++ { + strings.Index(str, pat) + } 
+ fmt.Println(" ", time.Since(t), "\tstrings.Index") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Index(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.Index") + + fmt.Println("IndexAny") + t = time.Now() + for i := 0; i < 1; i++ { + strings.IndexAny(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.IndexAny") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.IndexAny(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.IndexAny") + + str = strings.Repeat("0123456789", 100000) + "something" + pat = "123456789" + fmt.Println("Count") + t = time.Now() + for i := 0; i < 1; i++ { + strings.Count(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.Count") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Count(str, pat, false) + } + fmt.Println(" ", time.Since(t), "\tsz.Count") + +} diff --git a/scripts/test.go b/scripts/test.go new file mode 100644 index 00000000..07faa6f8 --- /dev/null +++ b/scripts/test.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "runtime" + "strings" + + sz "../go/stringzilla" +) + +func assertEqual[T comparable](act T, exp T) int { + if exp == act { + return 0 + } + _, _, line, _ := runtime.Caller(1) + fmt.Println("") + fmt.Println(" ERROR line ", line, " expected (", exp, ") is not equal to actual (", act, ")") + return 1 +} + +func main() { + + str := strings.Repeat("0123456789", 100000) + "something" + pat := "some" + ret := 0 + + fmt.Print("Contains ... ") + ret |= assertEqual(sz.Contains("", ""), true) + ret |= assertEqual(sz.Contains("test", ""), true) + ret |= assertEqual(sz.Contains("test", "s"), true) + ret |= assertEqual(sz.Contains("test", "test"), true) + ret |= assertEqual(sz.Contains("test", "zest"), false) + ret |= assertEqual(sz.Contains("test", "z"), false) + if ret == 0 { + fmt.Println("successful") + } + + fmt.Print("Index ... 
") + assertEqual(strings.Index(str, pat), int(sz.Index(str, pat))) + assertEqual(sz.Index("", ""), 0) + assertEqual(sz.Index("test", ""), 0) + assertEqual(sz.Index("test", "t"), 0) + assertEqual(sz.Index("test", "s"), 2) + fmt.Println("successful") + + fmt.Print("IndexAny ... ") + assertEqual(strings.IndexAny(str, pat), int(sz.IndexAny(str, pat))) + assertEqual(sz.IndexAny("test", "st"), 0) + assertEqual(sz.IndexAny("west east", "ta"), 3) + fmt.Println("successful") + + fmt.Print("Count ... ") + //assertEqual( strings.Count( str, pat ), int(sz.Count( str,pat,false )) ) + assertEqual(sz.Count("aaaaa", "a", false), 5) + assertEqual(sz.Count("aaaaa", "aa", false), 2) + assertEqual(sz.Count("aaaaa", "aa", true), 4) + fmt.Println("successful") + +} From 310f8c9e28b23063b0a5d487bfa8ce273b8ea54c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 13:38:57 +0000 Subject: [PATCH 2/7] Docs: Counting `.h` as C code lines --- .gitattributes | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..e0c79660 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,6 @@ +# GitHub's Linguist doesn't properly classify many languages +# https://github.com/github-linguist/linguist/blob/main/docs/overrides.md +*.h linguist-language=C +*.c linguist-language=C +*.hpp linguist-language=C++ +*.cpp linguist-language=C++ From 7bbb371874786c95fb42a747b40cf16846604d7e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 13:39:16 +0000 Subject: [PATCH 3/7] Make: Remove confusing `build.sh` --- scripts/build.sh | 65 ------------------------------------------------ 1 file changed, 65 deletions(-) delete mode 100755 scripts/build.sh diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index 600e5758..00000000 --- a/scripts/build.sh +++ /dev/null @@ -1,65 +0,0 
@@ -#!/bin/bash -# This Bash script compiles the CMake-based project with different compilers for different verrsions of C++ -# This is what should happen if only GCC 12 is installed and we are running on Sapphire Rapids. -# -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sandybridge" -B build_release/gcc-12-sandybridge && \ -# cmake --build build_release/gcc-12-sandybridge --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="haswell" -B build_release/gcc-12-haswell && \ -# cmake --build build_release/gcc-12-haswell --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sapphirerapids" -B build_release/gcc-12-sapphirerapids && \ -# cmake --build build_release/gcc-12-sapphirerapids --config Release - -# Array of target architectures -declare -a architectures=("sandybridge" "haswell" "sapphirerapids") - -# Function to get installed versions of a compiler -get_versions() { - local compiler_prefix=$1 - local versions=() - - echo "Checking for compilers in /usr/bin with prefix: $compiler_prefix" - - # Check if the directory /usr/bin exists and is a directory - if [ -d "/usr/bin" ]; then - for version in /usr/bin/${compiler_prefix}-*; do - echo "Checking: $version" - if [[ -x "$version" ]]; then - local ver=${version##*-} - echo "Found compiler version: $ver" - versions+=("$ver") - fi - done - else - echo "/usr/bin does not exist or is not a directory" - fi - - echo ${versions[@]} -} - -# Get installed versions of GCC and Clang -gcc_versions=$(get_versions gcc) -clang_versions=$(get_versions clang) - -# Compile for each combination of compiler and architecture -for arch in "${ARCHS[@]}"; do - for gcc_version in 
$gcc_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=g++-$gcc_version -DCMAKE_C_COMPILER=gcc-$gcc_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/gcc-$gcc_version-$arch" && \ - cmake --build "build_release/gcc-$gcc_version-$arch" --config Release - done - - for clang_version in $clang_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=clang++-$clang_version -DCMAKE_C_COMPILER=clang-$clang_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/clang-$clang_version-$arch" && \ - cmake --build "build_release/clang-$clang_version-$arch" --config Release - done -done - From 25c3b5ce02f6ca02818b640fc8d785dc52762975 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 15:08:55 +0000 Subject: [PATCH 4/7] Improve: GoLang 1.24 directives https://github.com/ashvardanian/SimSIMD/issues/28#issuecomment-2663031673 --- go/stringzilla/main.go | 213 +++++++++++++++++++++++++---------------- scripts/test.go | 2 - 2 files changed, 131 insertions(+), 84 deletions(-) diff --git a/go/stringzilla/main.go b/go/stringzilla/main.go index b9b566c3..b192dcff 100644 --- a/go/stringzilla/main.go +++ b/go/stringzilla/main.go @@ -1,127 +1,176 @@ +// StringZilla is a SIMD-accelerated string library for modern CPUs, written in C 99, +// and using AVX2, AVX512, Arm NEON, and SVE intrinsics to accelerate processing. +// +// The GoLang binding is intended to provide a simple interface to a precompiled +// shared library, available on GitHub: https://github.com/ashvardanian/stringzilla +// +// It requires Go 1.24 or newer to leverage the `cGo` `noescape` and `nocallback` +// directives. Without those the latency of calling C functions from Go is too high +// to be useful for string processing. 
package sz -// #cgo CFLAGS: -g -mavx2 -// #include -// #include <../../include/stringzilla/stringzilla.h> +// #cgo CFLAGS: -O3 +// #cgo LDFLAGS: -lstringzilla_shared +// #cgo noescape sz_find +// #cgo noescape sz_find_byte +// #cgo noescape sz_rfind +// #cgo noescape sz_rfind_byte +// #cgo noescape sz_find_char_from +// #cgo noescape sz_rfind_char_from +// #cgo noescape sz_look_up_transform +// #cgo noescape sz_hamming_distance +// #cgo noescape sz_hamming_distance_utf8 +// #cgo noescape sz_edit_distance +// #cgo noescape sz_edit_distance_utf8 +// #cgo noescape sz_alignment_score +// #cgo nocallback sz_find +// #cgo nocallback sz_find_byte +// #cgo nocallback sz_rfind +// #cgo nocallback sz_rfind_byte +// #cgo nocallback sz_find_char_from +// #cgo nocallback sz_rfind_char_from +// #cgo nocallback sz_look_up_transform +// #cgo nocallback sz_hamming_distance +// #cgo nocallback sz_hamming_distance_utf8 +// #cgo nocallback sz_edit_distance +// #cgo nocallback sz_edit_distance_utf8 +// #cgo nocallback sz_alignment_score +// #define SZ_DYNAMIC_DISPATCH 1 +// #include import "C" +import "unsafe" -// -Wall -O3 - -import ( - "unsafe" -) - -/* -// Passing a C function pointer around in go isn't working -//type searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t -//func _search( str string, pat string, searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t) uintptr { -func _search( str string, pat string, searchFunc C.sz_find_t ) uintptr { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := len(str) - patlen := len(pat) - ret := unsafe.Pointer( searchFunc(cstr, C.ulong(strlen), cpat, C.ulong(patlen)) ) - return ret +// Contains reports whether `substr` is within `str`. 
+// https://pkg.go.dev/strings#Contains +func Contains(str string, substr string) bool { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + return matchPtr != nil } -*/ -func Contains(str string, pat string) bool { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := len(str) - patlen := len(pat) - ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) - //ret := _search( str, pat, C.sz_find_t(C.sz_find) ) - return ret != nil +// Index returns the index of the first instance of `substr` in `str`, or -1 if `substr` is not present. +// https://pkg.go.dev/strings#Index +func Index(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func Index(str string, pat string) int64 { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := len(str) - patlen := len(pat) - ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) - if ret == nil { - return 0 +// Index returns the index of the last instance of `substr` in `str`, or -1 if `substr` is not present. 
+// https://pkg.go.dev/strings#LastIndex +func LastIndex(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_rfind(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 } - return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func Find(str string, pat string) int64 { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := len(str) - patlen := len(pat) - ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) - if ret == nil { +// Index returns the index of the first instance of a byte in `str`, or -1 if a byte is not present. +// https://pkg.go.dev/strings#IndexByte +func IndexByte(str string, c byte) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + cPtr := (*C.char)(unsafe.Pointer(&c)) + matchPtr := unsafe.Pointer(C.sz_find_byte(strPtr, C.ulong(strLen), cPtr)) + if matchPtr == nil { return -1 } - return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func LastIndex(str string, pat string) int64 { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := len(str) - patlen := len(pat) - ret := unsafe.Pointer(C.sz_rfind(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) - if ret == nil { +// Index returns the index of the last instance of a byte in `str`, or -1 if a byte is not present. 
+// https://pkg.go.dev/strings#LastIndexByte +func LastIndexByte(str string, c byte) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + cPtr := (*C.char)(unsafe.Pointer(&c)) + matchPtr := unsafe.Pointer(C.sz_rfind_byte(strPtr, C.ulong(strLen), cPtr)) + if matchPtr == nil { return -1 } - return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) -} -func RFind(str string, pat string) int64 { - return LastIndex(str, pat) + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func IndexAny(str string, charset string) int64 { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(charset))) - strlen := len(str) - patlen := len(charset) - ret := unsafe.Pointer(C.sz_find_char_from(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) - if ret == nil { +// Index returns the index of the first instance of any byte from `substr` in `str`, or -1 if none are present. +// https://pkg.go.dev/strings#IndexAny +func IndexAny(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_find_char_from(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { return -1 } - return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func FindCharFrom(str string, charset string) int64 { - return IndexAny(str, charset) + +// Index returns the index of the last instance of any byte from `substr` in `str`, or -1 if none are present. 
+// https://pkg.go.dev/strings#LastIndexAny +func LastIndexAny(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_rfind_char_from(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) } -func Count(str string, pat string, overlap bool) int64 { - cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) - strlen := int64(len(str)) - patlen := int64(len(pat)) +// Count returns the number of overlapping or non-overlapping instances of `substr` in `str`. +// If `substr` is an empty string, returns 1 + the number of Unicode code points in `str`. +// https://pkg.go.dev/strings#Count +func Count(str string, substr string, overlap bool) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) - if strlen == 0 || patlen == 0 || strlen < patlen { + if substrLen == 0 { + return 1 + len([]rune(str)) + } + if strLen == 0 || strLen < substrLen { return 0 } count := int64(0) if overlap == true { - for strlen > 0 { - ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + for strLen > 0 { + ret := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) if ret == nil { break } count += 1 - strlen -= (1 + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) - cstr = (*C.char)(unsafe.Add(ret, 1)) + strLen -= (1 + int64(uintptr(ret)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(ret, 1)) } } else { - for strlen > 0 { - ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + for strLen > 0 { + ret := 
unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) if ret == nil { break } count += 1 - strlen -= (patlen + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) - cstr = (*C.char)(unsafe.Add(ret, patlen)) + strLen -= (substrLen + int64(uintptr(ret)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(ret, substrLen)) } } return count - } diff --git a/scripts/test.go b/scripts/test.go index 07faa6f8..30418f91 100644 --- a/scripts/test.go +++ b/scripts/test.go @@ -25,8 +25,6 @@ func main() { ret := 0 fmt.Print("Contains ... ") - ret |= assertEqual(sz.Contains("", ""), true) - ret |= assertEqual(sz.Contains("test", ""), true) ret |= assertEqual(sz.Contains("test", "s"), true) ret |= assertEqual(sz.Contains("test", "test"), true) ret |= assertEqual(sz.Contains("test", "zest"), false) From 3949cf58bf4caac3394cf68c80599fc72772b9c5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 15:23:51 +0000 Subject: [PATCH 5/7] Make: Define the GoLang module --- .gitignore | 2 ++ .vscode/settings.json | 2 ++ CONTRIBUTING.md | 20 +++++++----- golang/go.mod | 3 ++ go/stringzilla/main.go => golang/lib.go | 41 ++++++++++--------------- scripts/test.go | 2 +- 6 files changed, 37 insertions(+), 33 deletions(-) create mode 100644 golang/go.mod rename go/stringzilla/main.go => golang/lib.go (82%) diff --git a/.gitignore b/.gitignore index fa4a8d0d..cbd6488a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ build/ build_debug/ build_release/ +build_go/ +build_golang/ build_artifacts* # Yes, everyone loves keeping this file in the history. 
diff --git a/.vscode/settings.json b/.vscode/settings.json index 980956d1..ffdb7ff1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -86,6 +86,8 @@ "Needleman", "newfunc", "NOARGS", + "nocallback", + "noescape", "noexcept", "NOMINMAX", "NOTIMPLEMENTED", diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e6ed8331..df053697 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -471,24 +471,30 @@ cargo package --list --allow-dirty If you want to run benchmarks against third-party implementations, check out the [`ashvardanian/memchr_vs_stringzilla`](https://github.com/ashvardanian/memchr_vs_stringzilla/) repository. -## Contributing in Go +## Contributing in GoLang + +First, precompile the C library: ```bash -export GO111MODULE="off" -go run scripts/test.go -go run scripts/bench.go +cmake -D STRINGZILLA_BUILD_SHARED=1 -D STRINGZILLA_BUILD_TEST=0 -D STRINGZILLA_BUILD_BENCHMARK=0 -B build_golang +cmake --build build_golang ``` -To run locally import with a relative path +Then, navigate to the GoLang module root directory and run the tests from there: ```bash - sz "../StringZilla/go/stringzilla" +CGO_CFLAGS="-I$(pwd)/../include" \ +CGO_LDFLAGS="-L$(pwd)/../build_golang -lstringzilla_shared" \ +LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ +go run ../scripts/test.go ``` -And turn off GO111MODULE +Alternatively: ```bash export GO111MODULE="off" +go run scripts/test.go +go run scripts/bench.go ``` ## General Recommendations diff --git a/golang/go.mod b/golang/go.mod new file mode 100644 index 00000000..fc14b40e --- /dev/null +++ b/golang/go.mod @@ -0,0 +1,3 @@ +module github.com/ashvardanian/stringzilla/golang + +go 1.24 diff --git a/go/stringzilla/main.go b/golang/lib.go similarity index 82% rename from go/stringzilla/main.go rename to golang/lib.go index b192dcff..3d4f9114 100644 --- a/go/stringzilla/main.go +++ b/golang/lib.go @@ -7,34 +7,25 @@ // It requires Go 1.24 or newer to leverage the `cGo` `noescape` and `nocallback` // directives. 
Without those the latency of calling C functions from Go is too high // to be useful for string processing. +// +// Unlike the native Go `strings` package, StringZilla primarily targets byte-level +// binary data processing, with less emphasis on UTF-8 and locale-specific tasks. package sz // #cgo CFLAGS: -O3 -// #cgo LDFLAGS: -lstringzilla_shared +// #cgo LDFLAGS: -L. -L/usr/local/lib -lstringzilla_shared // #cgo noescape sz_find // #cgo noescape sz_find_byte // #cgo noescape sz_rfind // #cgo noescape sz_rfind_byte // #cgo noescape sz_find_char_from // #cgo noescape sz_rfind_char_from -// #cgo noescape sz_look_up_transform -// #cgo noescape sz_hamming_distance -// #cgo noescape sz_hamming_distance_utf8 -// #cgo noescape sz_edit_distance -// #cgo noescape sz_edit_distance_utf8 -// #cgo noescape sz_alignment_score // #cgo nocallback sz_find // #cgo nocallback sz_find_byte // #cgo nocallback sz_rfind // #cgo nocallback sz_rfind_byte // #cgo nocallback sz_find_char_from // #cgo nocallback sz_rfind_char_from -// #cgo nocallback sz_look_up_transform -// #cgo nocallback sz_hamming_distance -// #cgo nocallback sz_hamming_distance_utf8 -// #cgo nocallback sz_edit_distance -// #cgo nocallback sz_edit_distance_utf8 -// #cgo nocallback sz_alignment_score // #define SZ_DYNAMIC_DISPATCH 1 // #include import "C" @@ -134,16 +125,16 @@ func LastIndexAny(str string, substr string) int64 { } // Count returns the number of overlapping or non-overlapping instances of `substr` in `str`. -// If `substr` is an empty string, returns 1 + the number of Unicode code points in `str`. +// If `substr` is an empty string, returns 1 + the length of the `str`. 
// https://pkg.go.dev/strings#Count func Count(str string, substr string, overlap bool) int64 { strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - strLen := len(str) + strLen := int64(len(str)) substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) - substrLen := len(substr) + substrLen := int64(len(substr)) if substrLen == 0 { - return 1 + len([]rune(str)) + return 1 + strLen } if strLen == 0 || strLen < substrLen { return 0 @@ -152,23 +143,23 @@ func Count(str string, substr string, overlap bool) int64 { count := int64(0) if overlap == true { for strLen > 0 { - ret := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) - if ret == nil { + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { break } count += 1 - strLen -= (1 + int64(uintptr(ret)-uintptr(unsafe.Pointer(strPtr)))) - strPtr = (*C.char)(unsafe.Add(ret, 1)) + strLen -= (1 + int64(uintptr(matchPtr)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(matchPtr, 1)) } } else { for strLen > 0 { - ret := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) - if ret == nil { + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { break } count += 1 - strLen -= (substrLen + int64(uintptr(ret)-uintptr(unsafe.Pointer(strPtr)))) - strPtr = (*C.char)(unsafe.Add(ret, substrLen)) + strLen -= (substrLen + int64(uintptr(matchPtr)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(matchPtr, substrLen)) } } diff --git a/scripts/test.go b/scripts/test.go index 30418f91..0a63575f 100644 --- a/scripts/test.go +++ b/scripts/test.go @@ -5,7 +5,7 @@ import ( "runtime" "strings" - sz "../go/stringzilla" + sz "github.com/ashvardanian/stringzilla/golang" ) func assertEqual[T comparable](act T, exp T) int { From 7085d927f2b0949499bb0a2d6b1fee6a0b85d7b3 Mon Sep 17 00:00:00 2001 From: Ash Vardanian 
<1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 15:34:05 +0000 Subject: [PATCH 6/7] Improve: New GoLang testing suite --- CONTRIBUTING.md | 3 +- golang/lib.go | 18 +++-- golang/lib_test.go | 169 +++++++++++++++++++++++++++++++++++++++++++++ scripts/test.go | 57 --------------- 4 files changed, 183 insertions(+), 64 deletions(-) create mode 100644 golang/lib_test.go delete mode 100644 scripts/test.go diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index df053697..c80c814d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -483,10 +483,11 @@ cmake --build build_golang Then, navigate to the GoLang module root directory and run the tests from there: ```bash +cd golang CGO_CFLAGS="-I$(pwd)/../include" \ CGO_LDFLAGS="-L$(pwd)/../build_golang -lstringzilla_shared" \ LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ -go run ../scripts/test.go +go test ``` Alternatively: diff --git a/golang/lib.go b/golang/lib.go index 3d4f9114..6762ebef 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -45,10 +45,13 @@ func Contains(str string, substr string) bool { // Index returns the index of the first instance of `substr` in `str`, or -1 if `substr` is not present. // https://pkg.go.dev/strings#Index func Index(str string, substr string) int64 { + substrLen := len(substr) + if substrLen == 0 { + return 0 + } strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) strLen := len(str) substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) - substrLen := len(substr) matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) if matchPtr == nil { return -1 @@ -59,10 +62,13 @@ func Index(str string, substr string) int64 { // Index returns the index of the last instance of `substr` in `str`, or -1 if `substr` is not present. 
// https://pkg.go.dev/strings#LastIndex func LastIndex(str string, substr string) int64 { + substrLen := len(substr) + strLen := int64(len(str)) + if substrLen == 0 { + return strLen + } strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) - strLen := len(str) substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) - substrLen := len(substr) matchPtr := unsafe.Pointer(C.sz_rfind(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) if matchPtr == nil { return -1 @@ -133,12 +139,12 @@ func Count(str string, substr string, overlap bool) int64 { substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) substrLen := int64(len(substr)) - if substrLen == 0 { - return 1 + strLen - } if strLen == 0 || strLen < substrLen { return 0 } + if substrLen == 0 { + return 1 + strLen + } count := int64(0) if overlap == true { diff --git a/golang/lib_test.go b/golang/lib_test.go new file mode 100644 index 00000000..561c47c5 --- /dev/null +++ b/golang/lib_test.go @@ -0,0 +1,169 @@ +package sz_test + +import ( + "strings" + "testing" + + sz "github.com/ashvardanian/stringzilla/golang" +) + +// TestContains verifies that the Contains function behaves as expected. +func TestContains(t *testing.T) { + tests := []struct { + s, substr string + want bool + }{ + {"test", "s", true}, + {"test", "test", true}, + {"test", "zest", false}, + {"test", "z", false}, + } + + for _, tt := range tests { + if got := sz.Contains(tt.s, tt.substr); got != tt.want { + t.Errorf("Contains(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want) + } + } +} + +// TestIndex compares our binding's Index against the standard strings.Index. +func TestIndex(t *testing.T) { + // We'll use both a long string and some simple cases. 
+ longStr := strings.Repeat("0123456789", 100000) + "something" + tests := []struct { + s, substr string + }{ + {longStr, "some"}, + {"test", ""}, + {"test", "t"}, + {"test", "s"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.Index(tt.s, tt.substr) + got := int(sz.Index(tt.s, tt.substr)) + if got != std { + t.Errorf("Index(%q, %q) = %d, want %d", tt.s, tt.substr, got, std) + } + } +} + +// TestLastIndex compares our binding's LastIndex against the standard strings.LastIndex. +func TestLastIndex(t *testing.T) { + tests := []struct { + s, substr string + }{ + {"test", "t"}, + {"test", "s"}, + {"test", ""}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.LastIndex(tt.s, tt.substr) + got := int(sz.LastIndex(tt.s, tt.substr)) + if got != std { + t.Errorf("LastIndex(%q, %q) = %d, want %d", tt.s, tt.substr, got, std) + } + } +} + +// TestIndexByte compares our binding's IndexByte against the standard strings.IndexByte. +func TestIndexByte(t *testing.T) { + tests := []struct { + s string + c byte + }{ + {"test", 't'}, + {"test", 's'}, + {"test", 'z'}, + } + + for _, tt := range tests { + std := strings.IndexByte(tt.s, tt.c) + got := int(sz.IndexByte(tt.s, tt.c)) + if got != std { + t.Errorf("IndexByte(%q, %q) = %d, want %d", tt.s, string(tt.c), got, std) + } + } +} + +// TestLastIndexByte compares our binding's LastIndexByte against the standard strings.LastIndexByte. +func TestLastIndexByte(t *testing.T) { + tests := []struct { + s string + c byte + }{ + {"test", 't'}, + {"test", 's'}, + {"test", 'z'}, + } + + for _, tt := range tests { + std := strings.LastIndexByte(tt.s, tt.c) + got := int(sz.LastIndexByte(tt.s, tt.c)) + if got != std { + t.Errorf("LastIndexByte(%q, %q) = %d, want %d", tt.s, string(tt.c), got, std) + } + } +} + +// TestIndexAny compares our binding's IndexAny against the standard strings.IndexAny. 
+func TestIndexAny(t *testing.T) { + tests := []struct { + s, charset string + }{ + {"test", "st"}, + {"west east", "ta"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.IndexAny(tt.s, tt.charset) + got := int(sz.IndexAny(tt.s, tt.charset)) + if got != std { + t.Errorf("IndexAny(%q, %q) = %d, want %d", tt.s, tt.charset, got, std) + } + } +} + +// TestLastIndexAny compares our binding's LastIndexAny against the standard strings.LastIndexAny. +func TestLastIndexAny(t *testing.T) { + tests := []struct { + s, charset string + }{ + {"test", "st"}, + {"west east", "ta"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.LastIndexAny(tt.s, tt.charset) + got := int(sz.LastIndexAny(tt.s, tt.charset)) + if got != std { + t.Errorf("LastIndexAny(%q, %q) = %d, want %d", tt.s, tt.charset, got, std) + } + } +} + +// TestCount verifies the Count function for overlapping and non-overlapping cases. +func TestCount(t *testing.T) { + tests := []struct { + s, substr string + overlap bool + want int + }{ + {"aaaaa", "a", false, 5}, + {"aaaaa", "aa", false, 2}, + {"aaaaa", "aa", true, 4}, + {"", "", false, 0}, // depending on your intended behavior, adjust as needed + } + + for _, tt := range tests { + got := int(sz.Count(tt.s, tt.substr, tt.overlap)) + if got != tt.want { + t.Errorf("Count(%q, %q, %v) = %d, want %d", tt.s, tt.substr, tt.overlap, got, tt.want) + } + } +} diff --git a/scripts/test.go b/scripts/test.go deleted file mode 100644 index 0a63575f..00000000 --- a/scripts/test.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "fmt" - "runtime" - "strings" - - sz "github.com/ashvardanian/stringzilla/golang" -) - -func assertEqual[T comparable](act T, exp T) int { - if exp == act { - return 0 - } - _, _, line, _ := runtime.Caller(1) - fmt.Println("") - fmt.Println(" ERROR line ", line, " expected (", exp, ") is not equal to actual (", act, ")") - return 1 -} - -func main() { - - str := strings.Repeat("0123456789", 100000) + 
"something" - pat := "some" - ret := 0 - - fmt.Print("Contains ... ") - ret |= assertEqual(sz.Contains("test", "s"), true) - ret |= assertEqual(sz.Contains("test", "test"), true) - ret |= assertEqual(sz.Contains("test", "zest"), false) - ret |= assertEqual(sz.Contains("test", "z"), false) - if ret == 0 { - fmt.Println("successful") - } - - fmt.Print("Index ... ") - assertEqual(strings.Index(str, pat), int(sz.Index(str, pat))) - assertEqual(sz.Index("", ""), 0) - assertEqual(sz.Index("test", ""), 0) - assertEqual(sz.Index("test", "t"), 0) - assertEqual(sz.Index("test", "s"), 2) - fmt.Println("successful") - - fmt.Print("IndexAny ... ") - assertEqual(strings.IndexAny(str, pat), int(sz.IndexAny(str, pat))) - assertEqual(sz.IndexAny("test", "st"), 0) - assertEqual(sz.IndexAny("west east", "ta"), 3) - fmt.Println("successful") - - fmt.Print("Count ... ") - //assertEqual( strings.Count( str, pat ), int(sz.Count( str,pat,false )) ) - assertEqual(sz.Count("aaaaa", "a", false), 5) - assertEqual(sz.Count("aaaaa", "aa", false), 2) - assertEqual(sz.Count("aaaaa", "aa", true), 4) - fmt.Println("successful") - -} From 4fbdd0bc2645c54939ab504d4d56e67527673343 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 16:24:54 +0000 Subject: [PATCH 7/7] Improve: Extended GoLang benchmarks --- CONTRIBUTING.md | 10 +++ golang/lib.go | 10 +-- scripts/bench.go | 155 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 125 insertions(+), 50 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c80c814d..1cf04d86 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -490,6 +490,16 @@ LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ go test ``` +To benchmark: + +```bash +cd golang +CGO_CFLAGS="-I$(pwd)/../include" \ +CGO_LDFLAGS="-L$(pwd)/../build_golang -lstringzilla_shared" \ +LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ +go run ../scripts/bench.go --input ../leipzig1M.txt +``` + 
Alternatively: ```bash diff --git a/golang/lib.go b/golang/lib.go index 6762ebef..4152d221 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -15,16 +15,16 @@ package sz // #cgo CFLAGS: -O3 // #cgo LDFLAGS: -L. -L/usr/local/lib -lstringzilla_shared // #cgo noescape sz_find -// #cgo noescape sz_find_byte -// #cgo noescape sz_rfind -// #cgo noescape sz_rfind_byte -// #cgo noescape sz_find_char_from -// #cgo noescape sz_rfind_char_from // #cgo nocallback sz_find +// #cgo noescape sz_find_byte // #cgo nocallback sz_find_byte +// #cgo noescape sz_rfind // #cgo nocallback sz_rfind +// #cgo noescape sz_rfind_byte // #cgo nocallback sz_rfind_byte +// #cgo noescape sz_find_char_from // #cgo nocallback sz_find_char_from +// #cgo noescape sz_rfind_char_from // #cgo nocallback sz_rfind_char_from // #define SZ_DYNAMIC_DISPATCH 1 // #include diff --git a/scripts/bench.go b/scripts/bench.go index a63380ea..ba02c9bf 100644 --- a/scripts/bench.go +++ b/scripts/bench.go @@ -1,70 +1,135 @@ package main import ( + "flag" "fmt" + "math/rand" + "os" "strings" + "testing" "time" - sz "../go/stringzilla" + sz "github.com/ashvardanian/stringzilla/golang" ) -func main() { +var sink any //? Global sink to defeat dead-code elimination - str := strings.Repeat("0123456789", 10000) + "something" - pat := "some" +// Repeats a certain function `f` multiple times and prints the benchmark results. +func runBenchmark[T any](name string, f func() T) { + benchResult := testing.Benchmark(func(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = f() + } + }) + fmt.Printf("%-30s: %s\n", name, benchResult.String()) +} - fmt.Println("Contains") - t := time.Now() - for i := 0; i < 1; i++ { - strings.Contains(str, pat) - } - fmt.Println(" ", time.Since(t), "\tstrings.Contains") +func main() { - t = time.Now() - for i := 0; i < 1; i++ { - sz.Contains(str, pat) - } - fmt.Println(" ", time.Since(t), "\tsz.Contains") + // Define command-line flags. 
+ inputPath := flag.String("input", "", "Path to input file for benchmarking. (Required)") + seedInt := flag.Int64("seed", 0, "Seed for the random number generator. If 0, the current time is used.") + splitMode := flag.String("split", "tokens", "How to split input file: 'tokens' (default) or 'lines'.") + flag.Parse() - fmt.Println("Index") - t = time.Now() - for i := 0; i < 1; i++ { - strings.Index(str, pat) + // Ensure input file is provided. + if *inputPath == "" { + fmt.Fprintln(os.Stderr, "Error: input file must be specified using the -input flag.") + flag.Usage() + os.Exit(1) } - fmt.Println(" ", time.Since(t), "\tstrings.Index") - t = time.Now() - for i := 0; i < 1; i++ { - sz.Index(str, pat) + // Read input data from file. + bytes, err := os.ReadFile(*inputPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading input file: %v\n", err) + os.Exit(1) } - fmt.Println(" ", time.Since(t), "\tsz.Index") + data := string(bytes) + fmt.Printf("Benchmarking on `%s` with seed %d.\n", *inputPath, *seedInt) + fmt.Printf("Total input length: %d\n", len(data)) - fmt.Println("IndexAny") - t = time.Now() - for i := 0; i < 1; i++ { - strings.IndexAny(str, pat) + // Split the data into items based on the chosen mode. + var items []string + switch *splitMode { + case "lines": + rawLines := strings.Split(data, "\n") + // Filter out empty lines. + for _, line := range rawLines { + if line != "" { + items = append(items, line) + } + } + if len(items) == 0 { + items = []string{"default"} + } + // Print line statistics. + totalLen := 0 + for _, line := range items { + totalLen += len(line) + } + fmt.Printf("Total lines: %d\n", len(items)) + fmt.Printf("Average line length: %.2f\n", float64(totalLen)/float64(len(items))) + default: // "tokens" or any other value defaults to token mode. 
+ items = strings.Fields(data) + if len(items) == 0 { + items = []string{"default"} + } + fmt.Printf("Total tokens: %d\n", len(items)) + fmt.Printf("Average token length: %.2f\n", float64(len(data))/float64(len(items))) } - fmt.Println(" ", time.Since(t), "\tstrings.IndexAny") - t = time.Now() - for i := 0; i < 1; i++ { - sz.IndexAny(str, pat) + // In Go, a string is represented as a (length, data) pair. If you pass a string around, + // Go will copy the length and the pointer but not the data pointed to. + // It's problematic for our benchmark as it makes substring operations meaningless - + // just comparing if a pointer falls in the range. + // To avoid that, let's copy strings to `[]byte` and back to force a new allocation. + for i, item := range items { + items[i] = string([]byte(item)) } - fmt.Println(" ", time.Since(t), "\tsz.IndexAny") - str = strings.Repeat("0123456789", 100000) + "something" - pat = "123456789" - fmt.Println("Count") - t = time.Now() - for i := 0; i < 1; i++ { - strings.Count(str, pat) + // Create a seeded reproducible random number generator. 
+ if *seedInt == 0 { + *seedInt = time.Now().UnixNano() } - fmt.Println(" ", time.Since(t), "\tstrings.Count") - - t = time.Now() - for i := 0; i < 1; i++ { - sz.Count(str, pat, false) + generator := rand.New(rand.NewSource(*seedInt)) + randomItem := func() string { + return items[generator.Intn(len(items))] } - fmt.Println(" ", time.Since(t), "\tsz.Count") + fmt.Println("Running benchmark using `testing.Benchmark`.") + + runBenchmark("strings.Contains", func() bool { + return strings.Contains(data, randomItem()) + }) + runBenchmark("sz.Contains", func() bool { + return sz.Contains(data, randomItem()) + }) + runBenchmark("strings.Index", func() int { + return strings.Index(data, randomItem()) + }) + runBenchmark("sz.Index", func() int64 { + return sz.Index(data, randomItem()) + }) + runBenchmark("strings.LastIndex", func() int { + return strings.LastIndex(data, randomItem()) + }) + runBenchmark("sz.LastIndex", func() int64 { + return sz.LastIndex(data, randomItem()) + }) + runBenchmark("strings.IndexAny", func() int { + return strings.IndexAny(randomItem(), "*^") + }) + runBenchmark("sz.IndexAny", func() int64 { + return sz.IndexAny(randomItem(), "*^") + }) + runBenchmark("strings.Count", func() int { + return strings.Count(data, randomItem()) + }) + runBenchmark("sz.Count (non-overlap)", func() int64 { + return sz.Count(data, randomItem(), false) + }) + runBenchmark("sz.Count (overlap)", func() int64 { + return sz.Count(data, randomItem(), true) + }) }