The memory consumption is almost the same #15231
DaShenZi721 asked this question in Q&A
I have recently been testing the memory usage of several baselines on the text tasks of the Long-Range Arena benchmark. However, regardless of which baseline or batch size I use, the memory usage stays almost the same (around 160 MB), as shown in the figure below.

The code can be found at:
https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/text_classification/train.py
I am using jax-smi, which starts an additional thread that runs
jax.profiler.save_device_memory_profile()
every second. I added the following two lines of code to the main function. My JAX environment is as follows:
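The periodic sampling described above can be sketched as follows. This is a minimal sketch that assumes only the public `jax.profiler.save_device_memory_profile(filename)` API; `start_memory_sampler` is a hypothetical helper written for illustration, not jax-smi's actual code. Each call writes a pprof-format snapshot of the device allocations that are live at that instant, so buffers allocated and freed between two samples do not appear in any snapshot.

```python
import os
import tempfile
import threading

import jax
import jax.numpy as jnp


def start_memory_sampler(path, interval_s=1.0):
    """Hypothetical helper: dump JAX's device memory profile every interval_s seconds.

    Returns (stop_event, thread); call stop_event.set() and then thread.join()
    to stop sampling. At least one snapshot is always written.
    """
    stop = threading.Event()

    def _loop():
        while True:
            # Snapshot of the device buffers alive right now (pprof format);
            # the file is overwritten on every iteration.
            jax.profiler.save_device_memory_profile(path)
            # Event.wait returns True as soon as stop.set() has been called.
            if stop.wait(interval_s):
                break

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return stop, thread


if __name__ == "__main__":
    out = os.path.join(tempfile.gettempdir(), "memory.prof")
    stop, thread = start_memory_sampler(out, interval_s=1.0)
    x = jnp.ones((1024, 1024))
    (x @ x).block_until_ready()
    stop.set()
    thread.join()
    print(os.path.getsize(out), "bytes written to", out)
```

The resulting file can be inspected with the `pprof` tool.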