Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 159 additions & 55 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import glob
import os
import platform
import shutil
import subprocess
import sys
from sysconfig import get_config_var
from typing import Any
Expand All @@ -24,6 +26,7 @@

LIBRARY = "opteryx"


def is_mac(): # pragma: no cover
return platform.system().lower() == "darwin"

Expand All @@ -32,7 +35,111 @@ def is_win(): # pragma: no cover
return platform.system().lower() == "windows"


REQUESTED_COMMANDS = {arg.lower() for arg in sys.argv[1:] if arg and not arg.startswith('-')}
def check_rust_availability(): # pragma: no cover
"""
Check if Rust toolchain is available and return paths.

Returns:
tuple: (rustc_path, cargo_path, rust_env) if available

Raises:
SystemExit: If Rust is not available with a helpful error message
"""
# Check environment variables first, treating empty strings as None
rustc_path = os.environ.get("RUSTC") or None
cargo_path = os.environ.get("CARGO") or None

# If not set via environment, check PATH
if not rustc_path:
rustc_path = shutil.which("rustc")
if not cargo_path:
cargo_path = shutil.which("cargo")

# Validate that paths exist if provided
if rustc_path and not os.path.isfile(rustc_path):
print(
f"\033[38;2;255;208;0mWarning:\033[0m RUSTC environment variable points to non-existent file: {rustc_path}",
file=sys.stderr,
)
rustc_path = None
if cargo_path and not os.path.isfile(cargo_path):
print(
f"\033[38;2;255;208;0mWarning:\033[0m CARGO environment variable points to non-existent file: {cargo_path}",
file=sys.stderr,
)
cargo_path = None

if not rustc_path or not cargo_path:
error_msg = """
\033[38;2;255;85;85m╔═══════════════════════════════════════════════════════════════════════════╗
║ RUST TOOLCHAIN NOT FOUND ║
╚═══════════════════════════════════════════════════════════════════════════╝\033[0m

Opteryx requires the Rust compiler (rustc) and Cargo to build native extensions.

\033[38;2;255;208;0mWhat's missing:\033[0m
"""
if not rustc_path:
error_msg += " • rustc (Rust compiler) not found in PATH\n"
if not cargo_path:
error_msg += " • cargo (Rust package manager) not found in PATH\n"

error_msg += """
\033[38;2;255;208;0mHow to install Rust:\033[0m

1. Visit: https://rustup.rs/
2. Run the installation command for your platform:

\033[38;2;139;233;253mLinux/macOS:\033[0m
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

\033[38;2;139;233;253mWindows:\033[0m
Download and run: https://win.rustup.rs/

3. After installation, restart your terminal or run:
source $HOME/.cargo/env

4. Verify installation:
rustc --version
cargo --version

\033[38;2;255;208;0mAlternative:\033[0m
If Rust is installed but not in PATH, you can set the CARGO and RUSTC
environment variables before running setup.py:

export CARGO=/path/to/cargo
export RUSTC=/path/to/rustc

\033[38;2;255;85;85m═══════════════════════════════════════════════════════════════════════════\033[0m
"""
print(error_msg, file=sys.stderr)
sys.exit(1)

# Get Rust version for informational purposes
try:
rust_version = subprocess.check_output(
[rustc_path, "--version"], text=True, stderr=subprocess.STDOUT
).strip()
print(f"\033[38;2;139;233;253mFound Rust toolchain:\033[0m {rust_version}")
print(f"\033[38;2;139;233;253mrustc path:\033[0m {rustc_path}")
print(f"\033[38;2;139;233;253mcargo path:\033[0m {cargo_path}")
except (subprocess.CalledProcessError, OSError) as e:
# If we can't run rustc --version, the path might not be a valid rustc binary
print(
f"\033[38;2;255;208;0mWarning:\033[0m Could not verify Rust installation at {rustc_path}: {e}",
file=sys.stderr,
)

# Create environment dict with explicit paths
rust_env = {
"RUSTC": rustc_path,
"CARGO": cargo_path,
}

return rustc_path, cargo_path, rust_env


REQUESTED_COMMANDS = {arg.lower() for arg in sys.argv[1:] if arg and not arg.startswith("-")}
SHOULD_BUILD_EXTENSIONS = "clean" not in REQUESTED_COMMANDS

if not SHOULD_BUILD_EXTENSIONS:
Expand All @@ -42,7 +149,6 @@ def is_win(): # pragma: no cover
)

if SHOULD_BUILD_EXTENSIONS:

CPP_COMPILE_FLAGS = ["-O3"]
C_COMPILE_FLAGS = ["-O3"]
if is_mac():
Expand Down Expand Up @@ -81,10 +187,15 @@ def is_win(): # pragma: no cover

print("\033[38;2;255;85;85mInclude paths:\033[0m", include_dirs)

# Check Rust availability and get paths
rustc_path, cargo_path, rust_env = check_rust_availability()

def rust_build(setup_kwargs: Dict[str, Any]) -> None:
setup_kwargs.update(
{
"rust_extensions": [RustExtension("opteryx.compute", "Cargo.toml", debug=False)],
"rust_extensions": [
RustExtension("opteryx.compute", "Cargo.toml", debug=False, env=rust_env)
],
"zip_safe": False,
}
)
Expand Down Expand Up @@ -128,20 +239,14 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
),
Extension(
name="opteryx.third_party.alantsd.base64",
sources=[
"opteryx/third_party/alantsd/base64.pyx",
"third_party/alantsd/base64.c"
],
sources=["opteryx/third_party/alantsd/base64.pyx", "third_party/alantsd/base64.c"],
include_dirs=include_dirs + ["third_party/alantsd"],
extra_compile_args=C_COMPILE_FLAGS + ["-std=c99", "-DBASE64_IMPLEMENTATION"],
extra_link_args=["-Lthird_party/alantsd"],
),
Extension(
name="opteryx.third_party.cyan4973.xxhash",
sources=[
"opteryx/third_party/cyan4973/xxhash.pyx",
"third_party/cyan4973/xxhash.c"
],
sources=["opteryx/third_party/cyan4973/xxhash.pyx", "third_party/cyan4973/xxhash.c"],
include_dirs=include_dirs + ["third_party/cyan4973"],
extra_compile_args=C_COMPILE_FLAGS,
extra_link_args=["-Lthird_party/cyan4973"],
Expand Down Expand Up @@ -217,10 +322,7 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
),
Extension(
name="opteryx.compiled.structures.buffers",
sources=[
"opteryx/compiled/structures/buffers.pyx",
"src/cpp/intbuffer.cpp"
],
sources=["opteryx/compiled/structures/buffers.pyx", "src/cpp/intbuffer.cpp"],
include_dirs=include_dirs,
language="c++",
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
Expand Down Expand Up @@ -251,10 +353,7 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
),
Extension(
name="opteryx.compiled.structures.jsonl_decoder",
sources=[
"opteryx/compiled/structures/jsonl_decoder.pyx",
"src/cpp/simd_search.cpp"
],
sources=["opteryx/compiled/structures/jsonl_decoder.pyx", "src/cpp/simd_search.cpp"],
include_dirs=include_dirs + ["third_party/fastfloat/fast_float"],
language="c++",
extra_compile_args=CPP_COMPILE_FLAGS,
Expand Down Expand Up @@ -302,13 +401,16 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
list_ops_file = os.path.join(list_ops_dir, "list_ops.pyx")

# Find all .pyx files in the list_ops directory (excluding list_ops.pyx itself)
pyx_files = sorted([
os.path.basename(f) for f in glob.glob(os.path.join(list_ops_dir, "*.pyx"))
if os.path.basename(f) != "list_ops.pyx"
])
pyx_files = sorted(
[
os.path.basename(f)
for f in glob.glob(os.path.join(list_ops_dir, "*.pyx"))
if os.path.basename(f) != "list_ops.pyx"
]
)

# Generate the list_ops.pyx file with include directives
with open(list_ops_file, 'w', encoding="UTF8") as f:
with open(list_ops_file, "w", encoding="UTF8") as f:
f.write("""# cython: language_level=3
# cython: nonecheck=False
# cython: cdivision=True
Expand All @@ -326,26 +428,31 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
\"\"\"

""")

# Add include directives for each .pyx file
for pyx_file in pyx_files:
f.write(f'include "{pyx_file}"\n')

print(f"\033[38;2;189;147;249mAuto-generated list_ops.pyx with {len(pyx_files)} includes\033[0m")
print(
f"\033[38;2;189;147;249mAuto-generated list_ops.pyx with {len(pyx_files)} includes\033[0m"
)

# Auto-generate joins.pyx to include all individual .pyx files in the folder
# This ensures new files are automatically included when added
joins_dir = "opteryx/compiled/joins"
joins_file = os.path.join(joins_dir, "joins.pyx")

# Find all .pyx files in the joins directory (excluding joins.pyx itself)
joins_pyx_files = sorted([
os.path.basename(f) for f in glob.glob(os.path.join(joins_dir, "*.pyx"))
if os.path.basename(f) != "joins.pyx"
])
joins_pyx_files = sorted(
[
os.path.basename(f)
for f in glob.glob(os.path.join(joins_dir, "*.pyx"))
if os.path.basename(f) != "joins.pyx"
]
)

# Generate the joins.pyx file with include directives
with open(joins_file, 'w', encoding="UTF8") as f:
with open(joins_file, "w", encoding="UTF8") as f:
f.write("""# cython: language_level=3
# cython: nonecheck=False
# cython: cdivision=True
Expand All @@ -363,26 +470,31 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
\"\"\"

""")

# Add include directives for each .pyx file
for pyx_file in joins_pyx_files:
f.write(f'include "{pyx_file}"\n')

print(f"\033[38;2;189;147;249mAuto-generated joins.pyx with {len(joins_pyx_files)} includes\033[0m")
print(
f"\033[38;2;189;147;249mAuto-generated joins.pyx with {len(joins_pyx_files)} includes\033[0m"
)

# Auto-generate functions.pyx to include all individual .pyx files in the folder
# This ensures new files are automatically included when added
functions_dir = "opteryx/compiled/functions"
functions_file = os.path.join(functions_dir, "functions.pyx")

# Find all .pyx files in the functions directory (excluding functions.pyx itself)
functions_pyx_files = sorted([
os.path.basename(f) for f in glob.glob(os.path.join(functions_dir, "*.pyx"))
if os.path.basename(f) != "functions.pyx"
])
functions_pyx_files = sorted(
[
os.path.basename(f)
for f in glob.glob(os.path.join(functions_dir, "*.pyx"))
if os.path.basename(f) != "functions.pyx"
]
)

# Generate the functions.pyx file with include directives
with open(functions_file, 'w', encoding="UTF8") as f:
with open(functions_file, "w", encoding="UTF8") as f:
f.write("""# cython: language_level=3
# cython: nonecheck=False
# cython: cdivision=True
Expand All @@ -400,12 +512,14 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
\"\"\"

""")

# Add include directives for each .pyx file
for pyx_file in functions_pyx_files:
f.write(f'include "{pyx_file}"\n')

print(f"\033[38;2;189;147;249mAuto-generated functions.pyx with {len(functions_pyx_files)} includes\033[0m")
print(
f"\033[38;2;189;147;249mAuto-generated functions.pyx with {len(functions_pyx_files)} includes\033[0m"
)

list_ops_link_args = []
if not is_mac():
Expand All @@ -414,16 +528,10 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
extensions.append(
Extension(
name="opteryx.compiled.list_ops.function_definitions",
sources=[
list_ops_file,
"src/cpp/simd_search.cpp"
],
sources=[list_ops_file, "src/cpp/simd_search.cpp"],
language="c++",
include_dirs=include_dirs + [
"third_party/abseil",
"third_party/apache",
"opteryx/third_party/apache"
],
include_dirs=include_dirs
+ ["third_party/abseil", "third_party/apache", "opteryx/third_party/apache"],
extra_compile_args=CPP_COMPILE_FLAGS,
extra_link_args=list_ops_link_args,
),
Expand All @@ -434,10 +542,7 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
name="opteryx.compiled.joins.join_definitions",
sources=[joins_file],
language="c++",
include_dirs=include_dirs + [
"third_party/abseil",
"third_party/fastfloat/fast_float"
],
include_dirs=include_dirs + ["third_party/abseil", "third_party/fastfloat/fast_float"],
extra_compile_args=CPP_COMPILE_FLAGS,
),
)
Expand All @@ -452,7 +557,6 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
),
)


setup_config = {
"name": LIBRARY,
"version": __version__,
Expand All @@ -472,7 +576,7 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
},
"package_data": {
"": ["*.pyx", "*.pxd"],
}
},
}

rust_build(setup_config)
Expand Down
Loading