@@ -17,8 +17,7 @@ RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-ge
1717 git \
1818 cmake \
1919 ninja-build \
20- build-essential \
21- ccache && \
20+ build-essential && \
2221 rm -rf /var/lib/apt/lists/*
2322
2423# Install Mambaforge
@@ -44,16 +43,8 @@ WORKDIR /workspace
4443# Install PyTorch with CUDA support
4544RUN pip install torch==2.7.1
4645
47- # Install build dependencies + 构建加速工具
48- RUN pip install --upgrade pip setuptools wheel build scikit-build-core[pyproject] pybind11 ninja psutil
49-
50- # 🚀 设置ccache编译缓存加速
51- ENV CCACHE_DIR=/tmp/ccache \
52- CCACHE_MAXSIZE=2G \
53- CCACHE_COMPRESS=true \
54- CC="ccache gcc" \
55- CXX="ccache g++"
56- RUN ccache --set-config=max_size=2G
46+ # Install build dependencies
47+ RUN pip install --upgrade pip setuptools wheel build scikit-build-core[pyproject] pybind11 ninja
5748
5849# Copy source code to container
5950COPY . .
@@ -68,68 +59,40 @@ RUN python -c "import torch; print(f'PyTorch installed at: {torch.__path__[0]}')
6859ENV FLASH_ATTENTION_FORCE_BUILD=TRUE \
6960 FLASH_ATTENTION_DISABLE_BACKWARD=TRUE \
7061 CUDA_HOME=/usr/local/cuda \
71- CUDA_ROOT=/usr/local/cuda \
72- CCACHE_DISABLE=0
62+ CUDA_ROOT=/usr/local/cuda
7363
7464# 🎯 关键修复:设置 CMAKE_PREFIX_PATH 让 CMake 找到 PyTorch
7565RUN TORCH_CMAKE_PATH=$(python -c "import torch; print(torch.utils.cmake_prefix_path)" ) && \
7666 echo "export CMAKE_PREFIX_PATH=$TORCH_CMAKE_PATH:\$CMAKE_PREFIX_PATH" >> ~/.bashrc && \
7767 echo "CMAKE_PREFIX_PATH=$TORCH_CMAKE_PATH" >> /etc/environment
7868
79- # 🚀 GitHub Actions优化:智能设置并行度(针对2核7GB限制)
80- RUN python -c "\
81- import os, psutil; \
82- cpu_cores = min(2, os.cpu_count()); \
83- available_memory_gb = min(7, psutil.virtual_memory().available / (1024**3)); \
84- memory_jobs = max(1, int(available_memory_gb / 3)); \
85- optimal_jobs = min(cpu_cores, memory_jobs, 2); \
86- nvcc_threads = optimal_jobs; \
87- print(f'🎯 CI优化: MAX_JOBS={optimal_jobs}, NVCC_THREADS={nvcc_threads}'); \
88- print(f'💾 估算资源: {available_memory_gb:.1f}GB, {cpu_cores}核'); \
89- f = open('/etc/environment', 'a'); \
90- f.write(f'MAX_JOBS={optimal_jobs}\n'); \
91- f.write(f'NVCC_THREADS={nvcc_threads}\n'); \
92- f.close()"
93-
9469# Create output directory
9570RUN mkdir -p /out
9671
97-
98-
9972# Build lightllm-kernel package (main project)
73+ # 🎯 关键:在构建时设置 CMAKE_PREFIX_PATH,让 CMake 找到 PyTorch
10074RUN echo "🔧 Building lightllm-kernel package..." && \
10175 echo "📋 Verifying PyTorch installation..." && \
10276 python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}'); print(f'CMake prefix path: {torch.utils.cmake_prefix_path}')" && \
103- eval $(cat /etc/environment | xargs -I {} echo export {}) && \
10477 TORCH_CMAKE_PATH=$(python -c "import torch; print(torch.utils.cmake_prefix_path)" ) && \
10578 echo "🔧 Setting CMAKE_PREFIX_PATH to: $TORCH_CMAKE_PATH" && \
106- echo "🚀 Using optimized settings: MAX_JOBS=$MAX_JOBS, NVCC_THREADS=$NVCC_THREADS" && \
10779 CMAKE_PREFIX_PATH="$TORCH_CMAKE_PATH:$CMAKE_PREFIX_PATH" python -m build --wheel --outdir /out/ && \
10880 echo "✅ lightllm-kernel build completed"
10981
110- # Build flash_attn_3 package (hopper) - 源码优化构建
111- RUN echo "🔧 Building flash_attn_3 from source with optimizations ..." && \
82+ # Build flash_attn_3 package (hopper)
83+ RUN echo "🔧 Building flash_attn_3 package ..." && \
11284 cd flash-attention/hopper && \
113- eval $(cat /etc/environment | xargs -I {} echo export {}) && \
114- echo "🚀 Optimized settings: MAX_JOBS=$MAX_JOBS, NVCC_THREADS=$NVCC_THREADS" && \
115- echo "⏰ GitHub Actions: Building within 6h time limit..." && \
116- MAX_JOBS=$MAX_JOBS NVCC_THREADS=$NVCC_THREADS FLASH_ATTN_CUDA_ARCHS=90 python setup.py bdist_wheel && \
85+ MAX_JOBS=2 NVCC_THREADS=2 FLASH_ATTN_CUDA_ARCHS=90 python setup.py bdist_wheel && \
11786 cp dist/*.whl /out/ && \
118- echo "✅ flash_attn_3 optimized source build completed"
119-
120- # 显示编译缓存统计(如果可用)
121- RUN ccache --show-stats 2>/dev/null || echo "💾 ccache stats not available"
87+ echo "✅ flash_attn_3 build completed"
12288
123- # Verify all wheels are built (源码构建验证)
89+ # Verify all wheels are built
12490RUN echo "📦 Final wheel packages:" && \
12591 ls -la /out/ && \
12692 WHEEL_COUNT=$(ls -1 /out/*.whl | wc -l) && \
127- echo "🎯 Total wheels built: $WHEEL_COUNT" && \
93+ echo "Total wheels built: $WHEEL_COUNT" && \
12894 if [ "$WHEEL_COUNT" -ne 2 ]; then \
129- echo "❌ ERROR: Expected 2 wheels (lightllm-kernel + flash_attn_3), found $WHEEL_COUNT" && \
130- echo "📋 Debug info:" && ls -la /out/ && \
131- exit 1; \
95+ echo "❌ Error: Expected 2 wheels, found $WHEEL_COUNT" && exit 1; \
13296 else \
133- echo "🎉 SUCCESS: All wheels built from optimized source compilation!" ; \
134- fi && \
135- echo "🕒 Optimized build completed within GitHub Actions time limit!"
97+ echo "✅ Successfully built all wheel packages" ; \
98+ fi
0 commit comments