diff --git a/.github/workflows/rocm-perf.yml b/.github/workflows/rocm-perf.yml
new file mode 100644
index 000000000000..fe024ff87c11
--- /dev/null
+++ b/.github/workflows/rocm-perf.yml
@@ -0,0 +1,116 @@
+name: ROCm DLM Performance Evaluations
+
+on:
+  push:
+
+jobs:
+  build-and-test-jax:
+    runs-on: mi-250
+    strategy:
+      matrix:
+        python: ["3.10"]
+        rocm: ["6.3.4"]
+
+    env:
+      BASE_IMAGE: "ubuntu:22.04"
+      PYTHON_VERSION: ${{ matrix.python }}
+      ROCM_VERSION: ${{ matrix.rocm }}
+      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
+      WORKSPACE_DIR: jax_rocm_perf_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
+
+    steps:
+      - name: Clean up old workdirs
+        run: |
+          docker run -v "$(pwd):/jax" ubuntu bash -c "chown -R $UID /jax/jax_rocm_perf_* || true"
+          rm -rf jax_rocm_perf_* || true
+          ls -l
+
+      - name: Print system info
+        run: |
+          whoami
+          printenv
+          df -h
+          rocm-smi || true
+
+      - name: Checkout JAX source
+        uses: actions/checkout@v4
+        with:
+          path: ${{ env.WORKSPACE_DIR }}
+
+      - name: Build JAX Docker Image
+        run: |
+          cd $WORKSPACE_DIR
+          python3 build/rocm/ci_build \
+            --rocm-version "$ROCM_VERSION" \
+            --base-docker "$BASE_IMAGE" \
+            --python-versions "$PYTHON_VERSION" \
+            --compiler=clang \
+            dist_docker \
+            --image-tag "$TEST_IMAGE"
+
+      - name: Checkout MaxText source
+        uses: actions/checkout@v4
+        with:
+          repository: ROCm/maxtext
+          ref: rv_jax
+          path: ${{ env.WORKSPACE_DIR }}/maxtext
+
+      - name: Launch container
+        run: |
+          docker run -d --name maxtext_container \
+            --network=host \
+            --device=/dev/kfd \
+            --device=/dev/dri \
+            --ipc=host \
+            --shm-size=64G \
+            --group-add=video \
+            --cap-add=SYS_PTRACE \
+            --security-opt seccomp=unconfined \
+            -v "$(pwd)/${{ env.WORKSPACE_DIR }}/maxtext:/maxtext" \
+            -w /maxtext \
+            "$TEST_IMAGE" \
+            tail -f /dev/null
+
+      - name: Run MaxText training and save logs
+        run: |
+          # The default run shell is `bash -e` without pipefail, so a failed
+          # training run would be hidden by the exit status of `tee`.
+          set -o pipefail
+          docker exec maxtext_container bash -c "pip install -r requirements.txt"
+          for config in \
+            MaxText/configs/models/gpu/llama2_7b_rocm.yml \
+            MaxText/configs/models/gpu/gemma_2b_rocm.yml \
+            MaxText/configs/models/gpu/gpt3_6b_rocm.yml \
+            MaxText/configs/models/gpu/mixtral_8x1b_rocm.yml; do
+            model_name=$(basename "$config" _rocm.yml)
+            echo "Running $model_name"
+            if [[ "$model_name" == "mixtral_8x1b" ]]; then
+              docker exec maxtext_container bash -c "export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 && python3 -m MaxText.train $config" \
+                | tee logs_${model_name}.log
+            else
+              docker exec maxtext_container bash -c "python3 -m MaxText.train $config" \
+                | tee logs_${model_name}.log
+            fi
+          done
+
+      - name: Analyze logs to compute median step time
+        run: |
+          pip install numpy
+          python3 ${{ env.WORKSPACE_DIR }}/build/rocm/analyze_maxtext_logs.py
+          cat summary.json
+
+      - name: Upload logs and summary
+        # Keep the logs available for debugging when an earlier step failed.
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: training-results
+          path: |
+            logs_*.log
+            summary.json
+
+      - name: Cleanup container
+        if: always()
+        run: |
+          docker stop maxtext_container || true
+          docker rm maxtext_container || true
diff --git a/build/rocm/analyze_maxtext_logs.py b/build/rocm/analyze_maxtext_logs.py
new file mode 100644
index 000000000000..ce08edd44e72
--- /dev/null
+++ b/build/rocm/analyze_maxtext_logs.py
@@ -0,0 +1,68 @@
+"""Summarize MaxText training step times from ``logs_<model>.log`` files.
+
+Scans the current directory for ``logs_*.log``, extracts every
+``completed step: <n>, seconds: <t>`` entry and writes per-model step-time
+statistics to ``summary.json``. Models whose log has no such entry are left
+out of the summary and reported on stderr.
+"""
+
+import glob
+import json
+import re
+import sys
+
+import numpy as np
+
+LOG_PREFIX = "logs_"
+LOG_SUFFIX = ".log"
+# Capture the step number as logged instead of renumbering the matches, so the
+# summary stays correct when a log starts mid-run or skips steps.
+STEP_RE = re.compile(r"completed step: (\d+), seconds: (\d+(?:\.\d+)?)")
+
+
+def parse_log(path):
+    """Return the ``(step, seconds)`` pairs logged in the file at ``path``."""
+    # Container output is not guaranteed to be clean UTF-8; do not crash on it.
+    with open(path, encoding="utf-8", errors="replace") as f:
+        matches = (STEP_RE.search(line) for line in f)
+        return [(int(m.group(1)), float(m.group(2))) for m in matches if m]
+
+
+def summarize(steps):
+    """Return step-time statistics for a non-empty list of ``(step, seconds)``."""
+    times = np.array([seconds for _, seconds in steps])
+
+    def stat(value):
+        return round(float(value), 3)
+
+    return {
+        "steps": [{"step": n, "time": t} for n, t in steps],
+        "min_step_time": stat(np.min(times)),
+        "q25_step_time": stat(np.percentile(times, 25)),
+        "median_step_time": stat(np.median(times)),
+        "mean_step_time": stat(np.mean(times)),
+        "q75_step_time": stat(np.percentile(times, 75)),
+        "max_step_time": stat(np.max(times)),
+        "steps_counted": len(steps),
+    }
+
+
+def main():
+    """Write ``summary.json`` for every ``logs_*.log`` in the current directory."""
+    summary = {}
+    # Sorted so summary.json lists models in the same order on every run.
+    for log in sorted(glob.glob(LOG_PREFIX + "*" + LOG_SUFFIX)):
+        model = log[len(LOG_PREFIX):-len(LOG_SUFFIX)]
+        steps = parse_log(log)
+        if not steps:
+            # Still omitted from the summary, but no longer silently.
+            print(f"warning: no completed steps in {log}", file=sys.stderr)
+            continue
+        summary[model] = summarize(steps)
+
+    with open("summary.json", "w", encoding="utf-8") as f:
+        json.dump(summary, f, indent=2)
+
+
+if __name__ == "__main__":
+    main()