Skip to content

Commit cbf9a8f

Browse files
chandrasekhard2recml authors
authored and committed
Add DLRM DCN v2 for Inference Benchmarking.
This is the check-in to match the inference implementation. Major features: 1. Added check-pointing during training. 2. Added the option to load checkpoint for evaluation only. 3. Added eval interval and steps. 4. Added bash scripts for training and eval PiperOrigin-RevId: 769699497
1 parent 4441ae5 commit cbf9a8f

File tree

6 files changed

+1251
-0
lines changed

6 files changed

+1251
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
#
# Launches the DLRM DCN v2 JAX entry point (dlrm_main.py) with a fixed set of
# hyperparameters, in the mode selected by MODE below.
#
# Usage:
#   TPU_NAME=<your-tpu-name> ./<this-script>
#
# The Python entry point is referenced by a relative path, so run this from
# the directory that contains the `recml/` package.
#
# Required environment:
#   TPU_NAME  Name of the TPU; the script aborts if it is unset or empty.

set -euo pipefail

# Both are exported empty, overriding any value inherited from the caller.
export LIBTPU_INIT_ARGS=
export XLA_FLAGS=

# Taken from the caller's environment instead of an unquoted <TPU_NAME>
# placeholder, which bash would parse as redirections (a syntax error).
export TPU_NAME="${TPU_NAME:?TPU_NAME must be set to the name of your TPU}"

export LEARNING_RATE=0.0034
export BATCH_SIZE=135168
export EMBEDDING_SIZE=128
export MODEL_DIR=/tmp/
# Patterns are single-quoted so the trailing `*` reaches the program verbatim.
export FILE_PATTERN='gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/train-*'
export NUM_STEPS=28000
export CHECKPOINT_INTERVAL=1500
export EVAL_INTERVAL=1500
export EVAL_FILE_PATTERN='gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-*'
export EVAL_STEPS=660
# NOTE(review): MODE=eval with RESTORE_CHECKPOINT=true -- confirm this is the
# intended mode for this particular script.
export MODE=eval
export EMBEDDING_THRESHOLD=21000
export LOGGING_INTERVAL=1500
export RESTORE_CHECKPOINT=true

# Flags live in an array: every value is passed as exactly one argument (no
# word splitting or globbing), and the invocation does not rely on
# backslash line continuations that a stray blank line would break.
flags=(
  --learning_rate="${LEARNING_RATE}"
  --batch_size="${BATCH_SIZE}"
  --embedding_size="${EMBEDDING_SIZE}"
  --embedding_threshold="${EMBEDDING_THRESHOLD}"
  --model_dir="${MODEL_DIR}"
  --file_pattern="${FILE_PATTERN}"
  --num_steps="${NUM_STEPS}"
  --save_checkpoint_interval="${CHECKPOINT_INTERVAL}"
  --restore_checkpoint="${RESTORE_CHECKPOINT}"
  --eval_interval="${EVAL_INTERVAL}"
  --eval_file_pattern="${EVAL_FILE_PATTERN}"
  --eval_steps="${EVAL_STEPS}"
  --mode="${MODE}"
  --logging_interval="${LOGGING_INTERVAL}"
)

python recml/inference/models/jax/DLRM_DCNv2/dlrm_main.py "${flags[@]}"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/bin/bash
#
# Launches the DLRM DCN v2 JAX entry point (dlrm_main.py) with a fixed set of
# hyperparameters, in the mode selected by MODE below.
#
# Usage:
#   TPU_NAME=<your-tpu-name> ./<this-script>
#
# The Python entry point is referenced by a relative path, so run this from
# the directory that contains the `recml/` package.
#
# Required environment:
#   TPU_NAME  Name of the TPU; the script aborts if it is unset or empty.

set -euo pipefail

# Both are exported empty, overriding any value inherited from the caller.
export LIBTPU_INIT_ARGS=
export XLA_FLAGS=

# Taken from the caller's environment instead of an unquoted <TPU_NAME>
# placeholder, which bash would parse as redirections (a syntax error).
export TPU_NAME="${TPU_NAME:?TPU_NAME must be set to the name of your TPU}"

export LEARNING_RATE=0.0034
export BATCH_SIZE=135168
export EMBEDDING_SIZE=128
export MODEL_DIR=/tmp/
# Patterns are single-quoted so the trailing `*` reaches the program verbatim.
export FILE_PATTERN='gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/train-*'
export NUM_STEPS=28000
export CHECKPOINT_INTERVAL=1500
export EVAL_INTERVAL=1500
export EVAL_FILE_PATTERN='gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-*'
export EVAL_STEPS=660
# NOTE(review): MODE=eval with RESTORE_CHECKPOINT=true -- confirm this is the
# intended mode for this particular script.
export MODE=eval
export EMBEDDING_THRESHOLD=21000
export LOGGING_INTERVAL=1500
export RESTORE_CHECKPOINT=true

# Flags live in an array: every value is passed as exactly one argument (no
# word splitting or globbing), and the invocation does not rely on
# backslash line continuations that a stray blank line would break.
flags=(
  --learning_rate="${LEARNING_RATE}"
  --batch_size="${BATCH_SIZE}"
  --embedding_size="${EMBEDDING_SIZE}"
  --embedding_threshold="${EMBEDDING_THRESHOLD}"
  --model_dir="${MODEL_DIR}"
  --file_pattern="${FILE_PATTERN}"
  --num_steps="${NUM_STEPS}"
  --save_checkpoint_interval="${CHECKPOINT_INTERVAL}"
  --restore_checkpoint="${RESTORE_CHECKPOINT}"
  --eval_interval="${EVAL_INTERVAL}"
  --eval_file_pattern="${EVAL_FILE_PATTERN}"
  --eval_steps="${EVAL_STEPS}"
  --mode="${MODE}"
  --logging_interval="${LOGGING_INTERVAL}"
)

python recml/inference/models/jax/DLRM_DCNv2/dlrm_main.py "${flags[@]}"
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
3+
# Running RecML Inference benchmarks
4+
5+
## Setup environment
6+
7+
### Export Env
8+
9+
```
10+
export TPU_NAME=
11+
export QR_NODE_NAME=
12+
export PROJECT=
13+
export ZONE=
14+
export ACCELERATOR_TYPE=
15+
export RUNTIME_VERSION=
16+
```
17+
18+
### Launch a TPU VM
19+
20+
```
21+
gcloud alpha compute tpus queued-resources create ${TPU_NAME} --node-id ${QR_NODE_NAME} --project ${PROJECT} --zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --runtime-version ${RUNTIME_VERSION}
22+
```
23+
24+
### Install dependencies
25+
26+
27+
#### Clone the RecML repository
28+
29+
```
30+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="git clone https://github.com/AI-Hypercomputer/RecML.git"
31+
```
32+
33+
#### Install requirements
34+
35+
```
36+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="cd RecML && pip install -r requirements.txt"
37+
```
38+
39+
#### Install jax and jaxlib nightly
40+
41+
```
42+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --force"
43+
```
44+
45+
#### Install JAX Sparsecore (jax-tpu-embedding)
46+
47+
```
48+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="pip install -U https://storage.googleapis.com/jax-tpu-embedding-whls/20250604/jax_tpu_embedding-0.1.0.dev20250604-cp310-cp310-manylinux_2_35_x86_64.whl --force"
49+
```
50+
51+
#### Install other dependencies
52+
53+
```
54+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="pip install -U tensorflow dm-tree flax google-metrax"
55+
```
56+
57+
#### Run workload
58+
59+
Note: Replace `<MODEL_NAME>` and `<TASK_NAME>` in the command below with the model and task you want to run before executing it.
60+
61+
```
62+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="TPU_NAME=${TPU_NAME} ./inference/benchmarks/<MODEL_NAME>/<TASK_NAME>"
63+
```

0 commit comments

Comments
 (0)