
Commit 158327f

Indraneil Paul authored and committed
Allow faster prototyping and add Cuda 11 dockerfile
1 parent 85937d9 commit 158327f

7 files changed: +267 -41 lines changed
File renamed without changes.

Generate.Dockerfile

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04

SHELL ["/bin/bash", "-c"]

# Setup Environment Variables
ENV CUDA_HOME=/usr/local/cuda \
    PYTHONUNBUFFERED=1 \
    TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"

# Setup System Utilities
RUN apt-get update --yes --quiet \
    && apt-get upgrade --yes --quiet \
    && DEBIAN_FRONTEND=noninteractive apt-get install --yes --quiet --no-install-recommends \
        apt-utils \
        autoconf \
        automake \
        bc \
        build-essential \
        ca-certificates \
        check \
        cmake \
        curl \
        dmidecode \
        emacs \
        g++ \
        gcc \
        git \
        iproute2 \
        jq \
        kmod \
        libaio-dev \
        libcurl4-openssl-dev \
        libgl1-mesa-glx \
        libglib2.0-0 \
        libgomp1 \
        libibverbs-dev \
        libnuma-dev \
        libnuma1 \
        libomp-dev \
        libsm6 \
        libssl-dev \
        libsubunit-dev \
        libsubunit0 \
        libtool \
        libxext6 \
        libxrender-dev \
        make \
        moreutils \
        net-tools \
        ninja-build \
        openssh-client \
        openssh-server \
        openssl \
        pkg-config \
        python3-dev \
        software-properties-common \
        sudo \
        unzip \
        util-linux \
        vim \
        wget \
        zlib1g-dev \
    && apt-get autoremove \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/

# Setup base Python to bootstrap Mamba
RUN add-apt-repository --yes ppa:deadsnakes/ppa \
    && apt-get update --yes --quiet
RUN DEBIAN_FRONTEND=noninteractive apt-get install --yes --quiet --no-install-recommends \
    python3.11 \
    python3.11-dev \
    python3.11-distutils \
    python3.11-lib2to3 \
    python3.11-gdbm \
    python3.11-tk \
    pip
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 999 \
    && update-alternatives --config python3 \
    && ln -s /usr/bin/python3 /usr/bin/python
RUN pip install --upgrade pip

# Setup optimized Mamba environment with required PyTorch dependencies
RUN wget -O /tmp/Miniforge.sh https://github.com/conda-forge/miniforge/releases/download/24.3.0-0/Mambaforge-24.3.0-0-Linux-x86_64.sh \
    && bash /tmp/Miniforge.sh -b -p /Miniforge \
    && source /Miniforge/etc/profile.d/conda.sh \
    && source /Miniforge/etc/profile.d/mamba.sh \
    && mamba update -y -q -n base -c defaults mamba \
    && mamba create -y -q -n Code-Eval python=3.11 setuptools=69.5.1 \
    && mamba activate Code-Eval \
    && mamba install -y -q -c conda-forge \
        charset-normalizer \
        gputil \
        ipython \
        numpy \
        pandas \
        scikit-learn \
        wandb \
    && mamba install -y -q -c intel \
        "mkl==2023" \
        "mkl-static==2023" \
        "mkl-include==2023" \
    && mamba install -y -q -c pytorch magma-cuda118 \
    && mamba clean -a -f -y

# Install VLLM precompiled with appropriate CUDA and ensure PyTorch is installed from the same version channel
RUN source /Miniforge/etc/profile.d/conda.sh \
    && source /Miniforge/etc/profile.d/mamba.sh \
    && mamba activate Code-Eval \
    && pip install https://github.com/vllm-project/vllm/releases/download/v0.4.0/vllm-0.4.0+cu118-cp311-cp311-manylinux1_x86_64.whl \
        --extra-index-url https://download.pytorch.org/whl/cu118

# Install Flash Attention
RUN source /Miniforge/etc/profile.d/conda.sh \
    && source /Miniforge/etc/profile.d/mamba.sh \
    && mamba activate Code-Eval \
    && export MAX_JOBS=$(($(nproc) - 2)) \
    && pip install --no-cache-dir ninja packaging psutil \
    && pip install flash-attn==2.5.8 --no-build-isolation

# Add a new user "wildcodeuser"
RUN adduser --disabled-password --gecos "" wildcodeuser

# Acquire benchmark code to local
RUN git clone https://github.com/NVIDIA/apex /wildcode

RUN chown -R wildcodeuser:wildcodeuser /wildcode
USER wildcodeuser

# Install Code-Eval and pre-load the dataset
RUN source /Miniforge/etc/profile.d/conda.sh \
    && source /Miniforge/etc/profile.d/mamba.sh \
    && mamba activate Code-Eval \
    && pip install wild-code --upgrade \
    && python -c "from wildcode.data import get_wildcodebench; get_wildcodebench()"

WORKDIR /wildcode

# Declare an argument for the huggingface token
ARG HF_TOKEN
RUN if [[ -n "$HF_TOKEN" ]] ; then /Miniforge/envs/Code-Eval/bin/huggingface-cli login --token $HF_TOKEN ; \
    else echo "No HuggingFace token specified. Access to gated or private models will be unavailable." ; fi

ENTRYPOINT ["/Miniforge/envs/Code-Eval/bin/python", "-m", "wildcode.generate"]
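
A minimal sketch of how the new image might be built and used for prototyping. The image tag (wildcode-generate) and the model name are illustrative and not part of this commit; the HF_TOKEN build argument and the CLI flags come from this Dockerfile and from wildcode/generate.py further below.

# Build the CUDA 11 generation image; HF_TOKEN is optional and only needed for gated models.
docker build -f Generate.Dockerfile --build-arg HF_TOKEN=$HF_TOKEN -t wildcode-generate .

# The ENTRYPOINT already runs `python -m wildcode.generate`, so only its arguments are passed.
docker run --rm --gpus all wildcode-generate \
    --model deepseek-ai/deepseek-coder-6.7b-instruct \
    --dataset wildcodebench \
    --backend vllm

With this sketch the generated samples stay inside the container under /wildcode; mounting a volume and pointing --save_path at it would be one way to keep them on the host.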

Requirements/requirements-eval.txt

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
beautifulsoup4==4.8.2
blake3==0.4.1
chardet==5.2.0
cryptography==38.0.0
datetime==5.5
Django==4.2.7
dnspython==2.6.1
docxtpl==0.11.5
Faker==20.1.0
flask_login==0.6.3
flask_restful==0.3.10
flask_wtf==1.2.1
Flask-Mail==0.9.1
flask==3.0.3
folium==0.16.0
gensim==4.3.2
geopandas==0.13.2
geopy==2.4.1
holidays==0.29
keras==2.11.0
Levenshtein==0.25.0
librosa==0.10.1
lxml==4.9.3
matplotlib==3.7.0
mechanize==0.4.9
natsort==7.1.1
networkx==2.6.3
nltk==3.8
numba==0.55.0
numpy==1.21.2
opencv-python-headless==4.9.0.80
openpyxl==3.1.2
pandas==2.0.3
Pillow==10.3.0
prettytable==3.10.0
psutil==5.9.5
pycryptodome==3.14.1
pyfakefs==5.4.1
pyquery==1.4.3
pytesseract==0.3.10
pytest==8.2.0
python_http_client==3.3.7
python-dateutil==2.9.0
python-docx==1.1.0
python-Levenshtein-wheels
pytz==2023.3.post1
PyYAML==6.0.1
requests_mock==1.11.0
requests==2.31.0
Requests==2.31.0
rsa==4.9
scikit-image==0.18.0
scikit-learn==1.3.1
scipy==1.7.2
seaborn==0.13.2
selenium==4.15.
sendgrid==6.11.0
shapely==2.0.4
soundfile==0.12.1
statsmodels==0.14.0
statsmodels==0.14.0
sympy==1.12
tensorflow==2.11.1
textblob==0.18.0
texttable==1.7.0
Werkzeug==3.0.1
wikipedia==1.4.0
wordcloud==1.9.3
wordninja==2.0.0
WTForms==3.1.2
xlrd==2.0.1
xlrd==2.0.1
xlwt==1.3.0
xmltodict==0.13.0

Requirements/requirements.txt

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
accelerate>=0.30.1
anthropic>=0.26.1
appdirs>=1.4.4
fire>=0.6.0
google-generativeai>=0.5.4
mistralai>=0.2.0
multipledispatch>=0.6.0
numpy>=1.19.5
openai>=1.11.1
Pympler>=1.0.1
rich>=12.3.0
stop-sequencer>=1.2.3
tempdir>=0.7.1
termcolor>=2.0.0
tqdm>=4.56.0
tree_sitter_languages>=1.10.2
tree-sitter==0.21.3
wget>=3.2

requirements.txt

Lines changed: 0 additions & 18 deletions
This file was deleted.

setup.cfg

Lines changed: 13 additions & 14 deletions
@@ -17,26 +17,25 @@ packages = find:
 python_requires = >=3.8
 dependency_links =
 install_requires =
-    wget>=3.2
-    tempdir>=0.7.1
-    multipledispatch>=0.6.0
+    accelerate>=0.30.1
+    anthropic>=0.26.1
     appdirs>=1.4.4
-    numpy>=1.19.5
-    tqdm>=4.56.0
-    termcolor>=2.0.0
     fire>=0.6.0
+    google-generativeai>=0.5.4
+    mistralai>=0.2.0
+    multipledispatch>=0.6.0
+    numpy>=1.19.5
     openai>=1.11.1
+    Pympler>=1.0.1
     rich>=12.3.0
+    stop-sequencer>=1.2.3
+    tempdir>=0.7.1
+    termcolor>=2.0.0
+    tqdm>=4.56.0
     tree_sitter_languages>=1.10.2
     tree-sitter==0.21.3
-    Pympler>=1.0.1
-    accelerate
-    vllm
-    anthropic
-    mistralai
-    stop-sequencer
-    google-generativeai
-
+    wget>=3.2
+
 [options.entry_points]
 console_scripts =
     wildcode.evaluate = wildcode.evaluate:main
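
With the dependency pins now consolidated in install_requires, a plain editable install is enough for local prototyping. A minimal sketch, assuming a standard setuptools layout in the repository root and that wildcode.evaluate:main exposes an argparse CLI:

# Editable install from the repository root (sketch).
pip install -e .
# The console script declared under [options.entry_points] then becomes available:
wildcode.evaluate --help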

wildcode/generate.py

Lines changed: 18 additions & 9 deletions
@@ -1,8 +1,6 @@
 import os
 import json
 import argparse
-from os import PathLike
-from typing import List

 from wildcode.model import DecoderBase, make_model
 from rich.progress import (
@@ -24,6 +22,7 @@ def codegen(
     n_samples=1,
     id_range=None,
     resume=True,
+    subsample_size=None,
 ):
     with Progress(
         TextColumn(f"{dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"),
@@ -36,10 +35,13 @@
         from wildcode.data import get_wildcodebench, write_jsonl

         dataset = get_wildcodebench()
+        if subsample_size:
+            if subsample_size < len(dataset):
+                dataset = dataset[:subsample_size]

         if model.is_direct_completion() and nl2code:
             raise Exception("Base model does not support direct completion for NL2Code tasks")
-
+
         # create save_path if it doesn't exist, e.g., a/b.jsonl
         dirname = os.path.dirname(save_path)
         if not os.path.exists(dirname) and dirname != "":
@@ -53,7 +55,7 @@
                 continue

             p_name = task_id.replace("/", "_")
-
+
             # read the existing file if save_path exists
             if os.path.exists(save_path):
                 with open(save_path, "r") as f:
@@ -103,12 +105,14 @@
             print(f"Generated {len(samples)} samples")
             write_jsonl(save_path, samples, append=True)
             sidx += len(outputs)
-
+

 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True, type=str)
     parser.add_argument("--dataset", required=True, type=str)
+    parser.add_argument("--save_path", default=None, type=str)
+    parser.add_argument("--subsample_size", default=None, type=int)
     parser.add_argument("--nl2code", action='store_true')
     parser.add_argument("--bs", default=1, type=int)
     parser.add_argument("--n_samples", default=1, type=int)
@@ -121,8 +125,8 @@ def main():
     parser.add_argument("--base_url", default=None, type=str)
     parser.add_argument("--tp", default=1, type=int)
     args = parser.parse_args()
-
-
+
+
     assert args.dataset in ["wildcodebench"], f"Invalid dataset {args.dataset}"
     assert args.backend in ["vllm", "hf", "openai", "mistral", "anthropic", "google"]

@@ -153,8 +157,12 @@ def main():
         task = "nl2c"
     else:
         task = "c2c"
-    save_path = args.model.replace("/", "--") + f"--{args.dataset}-{task}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
-
+
+    if not args.save_path:
+        save_path = args.model.replace("/", "--") + f"--{args.dataset}-{task}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
+    else:
+        save_path = args.save_path
+
     codegen(
         model=model_runner,
         save_path=save_path,
@@ -165,6 +173,7 @@ def main():
         n_samples=args.n_samples,
         resume=args.resume,
         id_range=args.id_range,
+        subsample_size=args.subsample_size,
     )

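The two new arguments support the faster-prototyping goal in the commit title: --subsample_size truncates the loaded benchmark to its first N tasks, and --save_path overrides the otherwise auto-generated output filename. A minimal sketch of a quick smoke run outside Docker; the model name and output path are illustrative, and without --save_path the filename is derived from the model, dataset, task, backend, temperature, and sample count as shown in the diff above.

python -m wildcode.generate \
    --model deepseek-ai/deepseek-coder-6.7b-instruct \
    --dataset wildcodebench \
    --backend vllm \
    --nl2code \
    --subsample_size 50 \
    --save_path results/smoke-test.jsonl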