* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
  schedulers to overlap asynchronous communication with computation efficiently.
  The default value is False.
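
  As a usage sketch, XLA flags such as this one are commonly passed through the
  `XLA_FLAGS` environment variable, set before JAX initializes its backends:

  ```python
  import os

  # Set XLA flags before importing JAX so they take effect when the backend
  # initializes; multiple flags go in one space-separated string.
  os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"

  import jax  # imported after XLA_FLAGS is set

  print(jax.devices())
  ```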
* **--xla_gpu_memory_limit_slop_factor** This flag is a percentage applied to
  the total available memory, creating a threshold that guides the Latency Hiding
  Scheduler (LHS) in balancing memory reduction and latency hiding optimizations.
  The default value is 95.

  This factor effectively establishes a memory limit for compiler passes, determining
  when the scheduler should prioritize:
  1. Memory reduction: When memory usage approaches or exceeds the calculated threshold.
  2. Latency hiding: When memory usage is below the threshold, allowing for more
     aggressive optimizations that may temporarily increase memory usage but improve
     overall performance.

  By adjusting this factor, users can fine-tune the trade-off between memory efficiency
  and performance optimizations.
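
  To make the arithmetic concrete, here is a minimal illustrative sketch; the
  device size is an assumed figure, and the actual limit is computed internally
  by XLA's compiler passes:

  ```python
  # Illustrative only: how a slop factor of 95 scales an assumed memory budget.
  total_available_bytes = 80 * 1024**3   # assume an 80 GiB device
  slop_factor = 95                       # --xla_gpu_memory_limit_slop_factor

  # Memory usage above this threshold pushes the LHS toward memory reduction;
  # below it, the LHS can schedule more aggressive latency hiding.
  threshold_bytes = total_available_bytes * slop_factor // 100
  print(f"LHS threshold: {threshold_bytes / 1024**3:.1f} GiB")  # -> 76.0 GiB
  ```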
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
  i-th layer computation. It also enables overlapping (i+1)-th layer