@@ -174,7 +174,33 @@ Run the following command to verify that ROCm JAX is installed correctly:
174174
175175Follow these steps to build JAX with ROCm support from source:
176176
177- ### Step 1: Clone the Repository
177+ ### Step 1: Install ROCm
178+
179+ Please follow [ ROCm installation guide] ( https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html ) to install ROCm on your system.
180+
181+ Once installed, verify ROCm installation using:
182+
183+ ``` Bash
184+ > rocm-smi
185+
186+ ========================================== ROCm System Management Interface ==========================================
187+ ==================================================== Concise Info ====================================================
188+ Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
189+ Name (20 chars) (Junction) (Socket) (Mem, Compute)
190+ ======================================================================================================================
191+ 0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0%
192+ AMD Instinct MI300X
193+ 1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
194+ AMD Instinct MI300X
195+ 2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
196+ AMD Instinct MI300X
197+ 3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
198+ AMD Instinct MI300X
199+ ======================================================================================================================
200+ ================================================ End of ROCm SMI Log =================================================
201+ ```
202+
203+ ### Step 2: Clone the Repository
178204
179205Clone the ROCm-specific fork of JAX for the desired branch:
180206
@@ -183,13 +209,15 @@ Clone the ROCm-specific fork of JAX for the desired branch:
183209> cd jax
184210```
185211
186- ### Step 2 : Build the Wheels
212+ ### Step 3 : Build the Wheels
187213
188214Run the following command to build the necessary wheels:
189215
190216``` Bash
191- > python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
192- --rocm_version=60 --rocm_path=/opt/rocm-[version]
217+ > python3 ./build/build.py build \
218+ --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
219+ --rocm_path=/opt/rocm-[version] \
220+ --clang_path=/opt/rocm-[version]/lib/llvm/bin/clang
193221```
194222
195223This will generate three wheels in the ` dist/ ` directory:
@@ -198,10 +226,10 @@ This will generate three wheels in the `dist/` directory:
198226* jax-rocm-plugin (ROCm-specific plugin)
199227* jax-rocm-pjrt (ROCm-specific runtime)
200228
201- ### Step 3 : Then install custom JAX using:
229+ ### Step 4 : Then install custom JAX using:
202230
203231``` Bash
204- > python3 setup.py develop --user && pip3 -m pip install dist/* .whl
232+ > python3 setup.py develop --user && python3 -m pip install dist/* .whl
205233```
206234
207235### Simplified Build Script
0 commit comments