Are there any structural code changes I need to make when running my JAX/Flax model notebook on a GCP TPU instance vs. a local macOS CPU? #11355
No, you don't need to make any changes for your code to run. You can be explicit about which platforms/devices to use, but by default JAX will use the GPU or TPU platform if one is available. (It also always warns you when no GPU or TPU is found, unless you explicitly request the CPU.) You can see the default device(s) JAX will run on by printing the output of

One thing to keep in mind is that Cloud TPUs have 8 cores, and so

WDYT?
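A minimal sketch of what this looks like in practice, assuming a recent JAX release. `jax.default_backend()` and `jax.devices()` report which platform JAX picked (`"cpu"` on a Mac, `"tpu"` on a TPU VM); the same `jax.jit` code then runs unchanged on either. A plain `jit` call uses a single device, so to spread work over all local TPU cores one common option is `jax.pmap` over a leading axis of size `jax.local_device_count()`. The `predict` function and array shapes are hypothetical placeholders, not taken from the question.

```python
import jax
import jax.numpy as jnp

# Platform JAX chose by default, e.g. "cpu" locally or "tpu" on a Cloud TPU VM.
print(jax.default_backend())
# Every device visible to this process (several TPU cores on a TPU VM).
print(jax.devices())

# Hypothetical model function: jit-compiled code needs no device-specific edits.
@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((4, 2))
x = jnp.ones((3, 4))
y = predict(w, x)  # runs on the default device (a single core)
print(y.shape)

# To use all local cores, pmap maps over a leading axis whose size equals
# the number of local devices; in_axes=(None, 0) broadcasts w to every core.
n = jax.local_device_count()
xs = jnp.ones((n, 3, 4))
ys = jax.pmap(predict, in_axes=(None, 0))(w, xs)
print(ys.shape)  # one (3, 2) result per device
```

On a CPU-only Mac `n` is 1, so the same script runs in both environments; on a TPU VM it fans out across every core JAX reports.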
I SSH'd into a Google Cloud TPU VM instance. Do I need to edit my JAX/Flax code in any way to have the TPU perform the computations (e.g., something like PyTorch's `.to(torch.device("cuda"))`)? Or would just running the code on the TPU instance suffice? I guess what I am asking is:
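For comparison with PyTorch's `.to("cuda")`, here is a small sketch, assuming a recent JAX version where arrays expose a `.devices()` method. Converting data with `jnp.asarray` already places it on JAX's default device (the TPU, when one is present), so no explicit move is required. If you do want to pin data to a particular device, `jax.device_put` with an entry from `jax.devices()` is the closest analogue. The array contents here are arbitrary example data.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Host-side NumPy data, e.g. a batch loaded from disk.
x_host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Implicit placement: lands on the default device, no .to("cuda") equivalent needed.
x = jnp.asarray(x_host)

# Explicit placement, roughly analogous to PyTorch's .to(device).
dev = jax.devices()[0]
x_explicit = jax.device_put(x_host, dev)
print(x_explicit.devices())  # the set containing the chosen device

# Operations run on the device their inputs live on.
z = (x_explicit * 2).sum()
print(float(z))  # 30.0
```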