Skip to content

Commit 8df2766

Browse files
Ruturaj4Github Actions
authored andcommitted
Add argument to override base docker in dockerfile
1 parent 6763fcf commit 8df2766

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

build/rocm/Dockerfile.ms

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
################################################################################
2-
FROM ubuntu:20.04 AS rocm_base
2+
ARG BASE_DOCKER=ubuntu:22.04
3+
FROM $BASE_DOCKER AS rocm_base
34
################################################################################
45

56
RUN --mount=type=cache,target=/var/cache/apt \

build/rocm/ci_build

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path):
143143

144144
def dist_docker(
145145
rocm_version,
146+
base_docker,
146147
python_versions,
147148
xla_path,
148149
rocm_build_job="",
@@ -168,6 +169,7 @@ def dist_docker(
168169
"--build-arg=ROCM_VERSION=%s" % rocm_version,
169170
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
170171
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
172+
"--build-arg=BASE_DOCKER=%s" % base_docker,
171173
"--build-arg=PYTHON_VERSION=%s" % python_version,
172174
"--build-arg=JAX_VERSION=%(jax_version)s" % md,
173175
"--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
@@ -231,6 +233,12 @@ def test(image_name):
231233

232234
def parse_args():
233235
p = argparse.ArgumentParser()
236+
p.add_argument(
237+
"--base-docker",
238+
default="",
239+
help="Argument to override base docker in dockerfile",
240+
)
241+
234242
p.add_argument(
235243
"--python-versions",
236244
type=lambda x: x.split(","),
@@ -308,6 +316,7 @@ def main():
308316
)
309317
dist_docker(
310318
args.rocm_version,
319+
args.base_docker,
311320
args.python_versions,
312321
args.xla_source_dir,
313322
rocm_build_job=args.rocm_build_job,

build/rocm/ci_build.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ PYTHON_VERSION="3.10"
4848
ROCM_VERSION="6.1.3"
4949
ROCM_BUILD_JOB=""
5050
ROCM_BUILD_NUM=""
51-
BASE_DOCKER="ubuntu:20.04"
51+
BASE_DOCKER="ubuntu:22.04"
5252
CUSTOM_INSTALL=""
5353
JAX_USE_CLANG=""
5454
POSITIONAL_ARGS=()
@@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do
9090
ROCM_BUILD_NUM="$2"
9191
shift 2
9292
;;
93+
--base_docker)
94+
BASE_DOCKER="$2"
95+
shift 2
96+
;;
9397
--use_clang)
9498
JAX_USE_CLANG="$2"
9599
shift 2
@@ -154,6 +158,7 @@ fi
154158
# which is the ROCm image that is shipped for users to use (i.e. distributable).
155159
./build/rocm/ci_build \
156160
--rocm-version $ROCM_VERSION \
161+
--base-docker $BASE_DOCKER \
157162
--python-versions $PYTHON_VERSION \
158163
--xla-source-dir=$XLA_CLONE_DIR \
159164
--rocm-build-job=$ROCM_BUILD_JOB \

0 commit comments

Comments
 (0)