diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..e0c79660 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,6 @@ +# GitHub's Linguist doesn't properly classify many languages +# https://github.com/github-linguist/linguist/blob/main/docs/overrides.md +*.h linguist-language=C +*.c linguist-language=C +*.hpp linguist-language=C++ +*.cpp linguist-language=C++ diff --git a/.gitignore b/.gitignore index fa4a8d0d..cbd6488a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ build/ build_debug/ build_release/ +build_go/ +build_golang/ build_artifacts* # Yes, everyone loves keeping this file in the history. diff --git a/.vscode/settings.json b/.vscode/settings.json index 980956d1..ffdb7ff1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -86,6 +86,8 @@ "Needleman", "newfunc", "NOARGS", + "nocallback", + "noescape", "noexcept", "NOMINMAX", "NOTIMPLEMENTED", diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524d6c49..1cf04d86 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -471,6 +471,43 @@ cargo package --list --allow-dirty If you want to run benchmarks against third-party implementations, check out the [`ashvardanian/memchr_vs_stringzilla`](https://github.com/ashvardanian/memchr_vs_stringzilla/) repository. 
+## Contributing in GoLang + +First, precompile the C library: + +```bash +cmake -D STRINGZILLA_BUILD_SHARED=1 -D STRINGZILLA_BUILD_TEST=0 -D STRINGZILLA_BUILD_BENCHMARK=0 -B build_golang +cmake --build build_golang +``` + +Then, navigate to the GoLang module root directory and run the tests from there: + +```bash +cd golang +CGO_CFLAGS="-I$(pwd)/../include" \ +CGO_LDFLAGS="-L$(pwd)/../build_golang -lstringzilla_shared" \ +LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ +go test +``` + +To benchmark: + +```bash +cd golang +CGO_CFLAGS="-I$(pwd)/../include" \ +CGO_LDFLAGS="-L$(pwd)/../build_golang -lstringzilla_shared" \ +LD_LIBRARY_PATH="$(pwd)/../build_golang:$LD_LIBRARY_PATH" \ +go run ../scripts/bench.go --input ../leipzig1M.txt +``` + +Alternatively: + +```bash +export GO111MODULE="off" +go run scripts/test.go +go run scripts/bench.go --input leipzig1M.txt +``` + ## General Recommendations ### Operations Not Worth Optimizing diff --git a/golang/go.mod b/golang/go.mod new file mode 100644 index 00000000..fc14b40e --- /dev/null +++ b/golang/go.mod @@ -0,0 +1,3 @@ +module github.com/ashvardanian/stringzilla/golang + +go 1.24 diff --git a/golang/lib.go b/golang/lib.go new file mode 100644 index 00000000..4152d221 --- /dev/null +++ b/golang/lib.go @@ -0,0 +1,173 @@ +// StringZilla is a SIMD-accelerated string library for modern CPUs, written in C99, +// and using AVX2, AVX512, Arm NEON, and SVE intrinsics to accelerate processing. +// +// The GoLang binding is intended to provide a simple interface to a precompiled +// shared library, available on GitHub: https://github.com/ashvardanian/stringzilla +// +// It requires Go 1.24 or newer to leverage the `cGo` `noescape` and `nocallback` +// directives. Without those the latency of calling C functions from Go is too high +// to be useful for string processing. +// +// Unlike the native Go `strings` package, StringZilla primarily targets byte-level +// binary data processing, with less emphasis on UTF-8 and locale-specific tasks. 
+package sz + +// #cgo CFLAGS: -O3 +// #cgo LDFLAGS: -L. -L/usr/local/lib -lstringzilla_shared +// #cgo noescape sz_find +// #cgo nocallback sz_find +// #cgo noescape sz_find_byte +// #cgo nocallback sz_find_byte +// #cgo noescape sz_rfind +// #cgo nocallback sz_rfind +// #cgo noescape sz_rfind_byte +// #cgo nocallback sz_rfind_byte +// #cgo noescape sz_find_char_from +// #cgo nocallback sz_find_char_from +// #cgo noescape sz_rfind_char_from +// #cgo nocallback sz_rfind_char_from +// #define SZ_DYNAMIC_DISPATCH 1 +// #include <stringzilla/stringzilla.h> +import "C" +import "unsafe" + +// Contains reports whether `substr` is within `str`; an empty `substr` is always reported as present, consistent with Index. +// https://pkg.go.dev/strings#Contains +func Contains(str string, substr string) bool { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + return substrLen == 0 || matchPtr != nil +} + +// Index returns the byte index of the first instance of `substr` in `str`, or -1 if `substr` is not present. +// https://pkg.go.dev/strings#Index +func Index(str string, substr string) int64 { + substrLen := len(substr) + if substrLen == 0 { + return 0 + } + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// LastIndex returns the index of the last instance of `substr` in `str`, or -1 if `substr` is not present. 
+// https://pkg.go.dev/strings#LastIndex +func LastIndex(str string, substr string) int64 { + substrLen := len(substr) + strLen := int64(len(str)) + if substrLen == 0 { + return strLen + } + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + matchPtr := unsafe.Pointer(C.sz_rfind(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// IndexByte returns the index of the first instance of byte `c` in `str`, or -1 if `c` is not present. +// https://pkg.go.dev/strings#IndexByte +func IndexByte(str string, c byte) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + cPtr := (*C.char)(unsafe.Pointer(&c)) + matchPtr := unsafe.Pointer(C.sz_find_byte(strPtr, C.ulong(strLen), cPtr)) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// LastIndexByte returns the index of the last instance of byte `c` in `str`, or -1 if `c` is not present. +// https://pkg.go.dev/strings#LastIndexByte +func LastIndexByte(str string, c byte) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + cPtr := (*C.char)(unsafe.Pointer(&c)) + matchPtr := unsafe.Pointer(C.sz_rfind_byte(strPtr, C.ulong(strLen), cPtr)) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// IndexAny returns the index of the first instance of any byte from `substr` in `str`, or -1 if none are present. 
+// https://pkg.go.dev/strings#IndexAny +func IndexAny(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_find_char_from(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// LastIndexAny returns the index of the last instance of any byte from `substr` in `str`, or -1 if none are present. +// https://pkg.go.dev/strings#LastIndexAny +func LastIndexAny(str string, substr string) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := len(str) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := len(substr) + matchPtr := unsafe.Pointer(C.sz_rfind_char_from(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + return -1 + } + return int64(uintptr(matchPtr) - uintptr(unsafe.Pointer(strPtr))) +} + +// Count returns the number of overlapping or non-overlapping instances of `substr` in `str`, as selected by `overlap`. +// If `str` is empty or shorter than `substr`, returns 0; otherwise, if `substr` is empty, returns 1 + the byte length of `str`. 
+// https://pkg.go.dev/strings#Count +func Count(str string, substr string, overlap bool) int64 { + strPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + strLen := int64(len(str)) + substrPtr := (*C.char)(unsafe.Pointer(unsafe.StringData(substr))) + substrLen := int64(len(substr)) + + if strLen == 0 || strLen < substrLen { + return 0 + } + if substrLen == 0 { + return 1 + strLen + } + + count := int64(0) + if overlap == true { + for strLen > 0 { + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + break + } + count += 1 + strLen -= (1 + int64(uintptr(matchPtr)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(matchPtr, 1)) + } + } else { + for strLen > 0 { + matchPtr := unsafe.Pointer(C.sz_find(strPtr, C.ulong(strLen), substrPtr, C.ulong(substrLen))) + if matchPtr == nil { + break + } + count += 1 + strLen -= (substrLen + int64(uintptr(matchPtr)-uintptr(unsafe.Pointer(strPtr)))) + strPtr = (*C.char)(unsafe.Add(matchPtr, substrLen)) + } + } + + return count +} diff --git a/golang/lib_test.go b/golang/lib_test.go new file mode 100644 index 00000000..561c47c5 --- /dev/null +++ b/golang/lib_test.go @@ -0,0 +1,169 @@ +package sz_test + +import ( + "strings" + "testing" + + sz "github.com/ashvardanian/stringzilla/golang" +) + +// TestContains verifies that the Contains function behaves as expected. +func TestContains(t *testing.T) { + tests := []struct { + s, substr string + want bool + }{ + {"test", "s", true}, + {"test", "test", true}, + {"test", "zest", false}, + {"test", "z", false}, + } + + for _, tt := range tests { + if got := sz.Contains(tt.s, tt.substr); got != tt.want { + t.Errorf("Contains(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want) + } + } +} + +// TestIndex compares our binding's Index against the standard strings.Index. +func TestIndex(t *testing.T) { + // We'll use both a long string and some simple cases. 
+ longStr := strings.Repeat("0123456789", 100000) + "something" + tests := []struct { + s, substr string + }{ + {longStr, "some"}, + {"test", ""}, + {"test", "t"}, + {"test", "s"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.Index(tt.s, tt.substr) + got := int(sz.Index(tt.s, tt.substr)) + if got != std { + t.Errorf("Index(%q, %q) = %d, want %d", tt.s, tt.substr, got, std) + } + } +} + +// TestLastIndex compares our binding's LastIndex against the standard strings.LastIndex. +func TestLastIndex(t *testing.T) { + tests := []struct { + s, substr string + }{ + {"test", "t"}, + {"test", "s"}, + {"test", ""}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.LastIndex(tt.s, tt.substr) + got := int(sz.LastIndex(tt.s, tt.substr)) + if got != std { + t.Errorf("LastIndex(%q, %q) = %d, want %d", tt.s, tt.substr, got, std) + } + } +} + +// TestIndexByte compares our binding's IndexByte against the standard strings.IndexByte. +func TestIndexByte(t *testing.T) { + tests := []struct { + s string + c byte + }{ + {"test", 't'}, + {"test", 's'}, + {"test", 'z'}, + } + + for _, tt := range tests { + std := strings.IndexByte(tt.s, tt.c) + got := int(sz.IndexByte(tt.s, tt.c)) + if got != std { + t.Errorf("IndexByte(%q, %q) = %d, want %d", tt.s, string(tt.c), got, std) + } + } +} + +// TestLastIndexByte compares our binding's LastIndexByte against the standard strings.LastIndexByte. +func TestLastIndexByte(t *testing.T) { + tests := []struct { + s string + c byte + }{ + {"test", 't'}, + {"test", 's'}, + {"test", 'z'}, + } + + for _, tt := range tests { + std := strings.LastIndexByte(tt.s, tt.c) + got := int(sz.LastIndexByte(tt.s, tt.c)) + if got != std { + t.Errorf("LastIndexByte(%q, %q) = %d, want %d", tt.s, string(tt.c), got, std) + } + } +} + +// TestIndexAny compares our binding's IndexAny against the standard strings.IndexAny. 
+func TestIndexAny(t *testing.T) { + tests := []struct { + s, charset string + }{ + {"test", "st"}, + {"west east", "ta"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.IndexAny(tt.s, tt.charset) + got := int(sz.IndexAny(tt.s, tt.charset)) + if got != std { + t.Errorf("IndexAny(%q, %q) = %d, want %d", tt.s, tt.charset, got, std) + } + } +} + +// TestLastIndexAny compares our binding's LastIndexAny against the standard strings.LastIndexAny. +func TestLastIndexAny(t *testing.T) { + tests := []struct { + s, charset string + }{ + {"test", "st"}, + {"west east", "ta"}, + {"test", "z"}, + } + + for _, tt := range tests { + std := strings.LastIndexAny(tt.s, tt.charset) + got := int(sz.LastIndexAny(tt.s, tt.charset)) + if got != std { + t.Errorf("LastIndexAny(%q, %q) = %d, want %d", tt.s, tt.charset, got, std) + } + } +} + +// TestCount verifies the Count function for overlapping and non-overlapping cases. +func TestCount(t *testing.T) { + tests := []struct { + s, substr string + overlap bool + want int + }{ + {"aaaaa", "a", false, 5}, + {"aaaaa", "aa", false, 2}, + {"aaaaa", "aa", true, 4}, + {"", "", false, 0}, // depending on your intended behavior, adjust as needed + } + + for _, tt := range tests { + got := int(sz.Count(tt.s, tt.substr, tt.overlap)) + if got != tt.want { + t.Errorf("Count(%q, %q, %v) = %d, want %d", tt.s, tt.substr, tt.overlap, got, tt.want) + } + } +} diff --git a/scripts/bench.go b/scripts/bench.go new file mode 100644 index 00000000..ba02c9bf --- /dev/null +++ b/scripts/bench.go @@ -0,0 +1,135 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "os" + "strings" + "testing" + "time" + + sz "github.com/ashvardanian/stringzilla/golang" +) + +var sink any //? Global sink to defeat dead-code elimination + +// Repeats a certain function `f` multiple times and prints the benchmark results. 
+func runBenchmark[T any](name string, f func() T) { + benchResult := testing.Benchmark(func(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = f() + } + }) + fmt.Printf("%-30s: %s\n", name, benchResult.String()) +} + +func main() { + + // Define command-line flags. + inputPath := flag.String("input", "", "Path to input file for benchmarking. (Required)") + seedInt := flag.Int64("seed", 0, "Seed for the random number generator. If 0, the current time is used.") + splitMode := flag.String("split", "tokens", "How to split input file: 'tokens' (default) or 'lines'.") + flag.Parse() + + // Ensure input file is provided. + if *inputPath == "" { + fmt.Fprintln(os.Stderr, "Error: input file must be specified using the -input flag.") + flag.Usage() + os.Exit(1) + } + + // Read input data from file. + bytes, err := os.ReadFile(*inputPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading input file: %v\n", err) + os.Exit(1) + } + data := string(bytes) + fmt.Printf("Benchmarking on `%s` with seed %d.\n", *inputPath, *seedInt) + fmt.Printf("Total input length: %d\n", len(data)) + + // Split the data into items based on the chosen mode. + var items []string + switch *splitMode { + case "lines": + rawLines := strings.Split(data, "\n") + // Filter out empty lines. + for _, line := range rawLines { + if line != "" { + items = append(items, line) + } + } + if len(items) == 0 { + items = []string{"default"} + } + // Print line statistics. + totalLen := 0 + for _, line := range items { + totalLen += len(line) + } + fmt.Printf("Total lines: %d\n", len(items)) + fmt.Printf("Average line length: %.2f\n", float64(totalLen)/float64(len(items))) + default: // "tokens" or any other value defaults to token mode. 
+ items = strings.Fields(data) + if len(items) == 0 { + items = []string{"default"} + } + fmt.Printf("Total tokens: %d\n", len(items)) + fmt.Printf("Average token length: %.2f\n", float64(len(data))/float64(len(items))) + } + + // In Go, a string is represented as a (length, data) pair. If you pass a string around, + // Go will copy the length and the pointer but not the data pointed to. + // It's problematic for our benchmark as it makes substring operations meaningless - + // just comparing if a pointer falls in the range. + // To avoid that, let's copy strings to `[]byte` and back to force a new allocation. + for i, item := range items { + items[i] = string([]byte(item)) + } + + // Create a seeded reproducible random number generator. + if *seedInt == 0 { + *seedInt = time.Now().UnixNano() + } + generator := rand.New(rand.NewSource(*seedInt)) + randomItem := func() string { + return items[generator.Intn(len(items))] + } + + fmt.Println("Running benchmark using `testing.Benchmark`.") + + runBenchmark("strings.Contains", func() bool { + return strings.Contains(data, randomItem()) + }) + runBenchmark("sz.Contains", func() bool { + return sz.Contains(data, randomItem()) + }) + runBenchmark("strings.Index", func() int { + return strings.Index(data, randomItem()) + }) + runBenchmark("sz.Index", func() int64 { + return sz.Index(data, randomItem()) + }) + runBenchmark("strings.LastIndex", func() int { + return strings.LastIndex(data, randomItem()) + }) + runBenchmark("sz.LastIndex", func() int64 { + return sz.LastIndex(data, randomItem()) + }) + runBenchmark("strings.IndexAny", func() int { + return strings.IndexAny(randomItem(), "*^") + }) + runBenchmark("sz.IndexAny", func() int64 { + return sz.IndexAny(randomItem(), "*^") + }) + runBenchmark("strings.Count", func() int { + return strings.Count(data, randomItem()) + }) + runBenchmark("sz.Count (non-overlap)", func() int64 { + return sz.Count(data, randomItem(), false) + }) + runBenchmark("sz.Count (overlap)", func() 
int64 { + return sz.Count(data, randomItem(), true) + }) +} diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index 600e5758..00000000 --- a/scripts/build.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash -# This Bash script compiles the CMake-based project with different compilers for different verrsions of C++ -# This is what should happen if only GCC 12 is installed and we are running on Sapphire Rapids. -# -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sandybridge" -B build_release/gcc-12-sandybridge && \ -# cmake --build build_release/gcc-12-sandybridge --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="haswell" -B build_release/gcc-12-haswell && \ -# cmake --build build_release/gcc-12-haswell --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sapphirerapids" -B build_release/gcc-12-sapphirerapids && \ -# cmake --build build_release/gcc-12-sapphirerapids --config Release - -# Array of target architectures -declare -a architectures=("sandybridge" "haswell" "sapphirerapids") - -# Function to get installed versions of a compiler -get_versions() { - local compiler_prefix=$1 - local versions=() - - echo "Checking for compilers in /usr/bin with prefix: $compiler_prefix" - - # Check if the directory /usr/bin exists and is a directory - if [ -d "/usr/bin" ]; then - for version in /usr/bin/${compiler_prefix}-*; do - echo "Checking: $version" - if [[ -x "$version" ]]; then - local ver=${version##*-} - echo "Found compiler version: $ver" - versions+=("$ver") - fi - done - else - echo "/usr/bin does not exist or is not a directory" - fi - - echo ${versions[@]} -} - -# Get installed versions of 
GCC and Clang -gcc_versions=$(get_versions gcc) -clang_versions=$(get_versions clang) - -# Compile for each combination of compiler and architecture -for arch in "${ARCHS[@]}"; do - for gcc_version in $gcc_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=g++-$gcc_version -DCMAKE_C_COMPILER=gcc-$gcc_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/gcc-$gcc_version-$arch" && \ - cmake --build "build_release/gcc-$gcc_version-$arch" --config Release - done - - for clang_version in $clang_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=clang++-$clang_version -DCMAKE_C_COMPILER=clang-$clang_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/clang-$clang_version-$arch" && \ - cmake --build "build_release/clang-$clang_version-$arch" --config Release - done -done -