@@ -20,37 +20,36 @@ ARG NVCR_IMAGE_VERSION=25.02-py3
2020# This is based on what is inside the NVCR image already
2121ARG PYTHON_VERSION=3.12
2222
23- # # Base Layer ########################################## ########################
24- FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION} AS dev
23+ # ####################### BUILDER ########################
24+ FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION} AS builder
2525
2626ARG USER=root
2727ARG USER_UID=0
2828ARG WORKDIR=/app
2929ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning
3030
3131ARG ENABLE_FMS_ACCELERATION=true
32- ARG ENABLE_AIM=true
33- ARG ENABLE_MLFLOW=true
34- ARG ENABLE_SCANNER=true
32+ ARG ENABLE_AIM=false
33+ ARG ENABLE_MLFLOW=false
34+ ARG ENABLE_SCANNER=false
3535ARG ENABLE_CLEARML=true
3636ARG ENABLE_TRITON_KERNELS=true
37- ARG ENABLE_MAMBA_SUPPORT=true
3837
3938# Ensures to always build mamba_ssm from source
4039ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm
4140
42- RUN python -m pip install --upgrade pip
43-
4441# upgrade torch as the base layer contains only torch 2.7
45- RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
42+ RUN python -m pip install --upgrade pip && \
43+ pip install --upgrade setuptools && \
44+ pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
4645
4746# Install main package + flash attention
4847COPY . ${SOURCE_DIR}
4948RUN cd ${SOURCE_DIR}
5049
51- RUN pip install --upgrade pip setuptools wheel
52- RUN pip install --no-cache-dir ${SOURCE_DIR}
53- RUN pip install --user --no-build-isolation ${SOURCE_DIR}[flash-attn]
50+ RUN pip install --no-cache-dir ${SOURCE_DIR} && \
51+ pip install --user --no-build-isolation ${SOURCE_DIR}[flash-attn] && \
52+ pip install --no-cache-dir --no-build-isolation ${SOURCE_DIR}[mamba]
5453
5554# Optional extras
5655RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
@@ -62,6 +61,12 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
6261 python -m fms_acceleration.cli install fms_acceleration_odm; \
6362 fi
6463
64+ RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \
65+ pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels" ; \
66+ fi
67+ RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
68+ pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \
69+ fi
6570RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
6671 pip install --no-cache-dir ${SOURCE_DIR}[aim]; \
6772 fi
@@ -71,15 +76,22 @@ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
7176RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
7277 pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
7378 fi
74- RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
75- pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \
76- fi
77- RUN if [[ "${ENABLE_MAMBA_SUPPORT}" == "true" ]]; then \
78- pip install --no-cache-dir ${SOURCE_DIR}[mamba]; \
79- fi
80- RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \
81- pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels" ; \
82- fi
79+
80+ # cleanup
81+ RUN rm -rf /root/.cache /tmp/* /opt/pytorch
82+
83+ # ####################### RUNTIME ########################
84+ FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION}
85+
86+ WORKDIR ${WORKDIR}
87+
88+ # Copy only Python site-packages + app
89+ COPY --from=builder /usr/local/lib/python3.12/dist-packages \
90+ /usr/local/lib/python3.12/dist-packages
91+ COPY --from=builder ${SOURCE_DIR} ${SOURCE_DIR}
92+
93+ # Runtime cleanup
94+ RUN rm -rf /opt/pytorch /root/.cache /tmp/*
8395
8496RUN chmod -R g+rwX $WORKDIR /tmp
8597RUN mkdir -p /.cache && chmod -R 777 /.cache
0 commit comments