@@ -16,8 +16,8 @@ use std::ffi::OsString;
16
16
use subprocess:: { Popen , PopenConfig , PopenError , Redirection } ;
17
17
use tracing:: info;
18
18
19
- // For now this will be disabled by default, more testing is needed
20
- const DEFAULT_MAX_SPLIT_SIZE_MB : & ' static str = "none " ;
19
+ // In most cases this gives the best performance for inferencing
20
+ const DEFAULT_PYTORCH_CUDA_ALLOC_CONF : & ' static str = "expandable_segments:True " ;
21
21
22
22
/// App Configuration
23
23
#[ derive( Parser , Debug , Clone ) ]
@@ -108,18 +108,16 @@ fn main() -> ExitCode {
108
108
& args. model_name , args. revision . as_deref ( )
109
109
) . expect ( "Could not find tokenizer for model" ) ;
110
110
111
- // Set max_split_size to default value if PYTORCH_CUDA_ALLOC_CONF is not set,
112
- // or unset it if PYTORCH_CUDA_ALLOC_CONF is set but empty
111
+ // Set PYTORCH_CUDA_ALLOC_CONF to default value if it's not set in the environment
113
112
let cuda_alloc_conf = match env:: var ( "PYTORCH_CUDA_ALLOC_CONF" ) {
114
- Err ( VarError :: NotPresent ) if DEFAULT_MAX_SPLIT_SIZE_MB == "none " => None ,
113
+ Err ( VarError :: NotPresent ) if DEFAULT_PYTORCH_CUDA_ALLOC_CONF == "" => None ,
115
114
Err ( VarError :: NotPresent ) => {
116
- let alloc_conf = format ! ( "max_split_size_mb:{}" , DEFAULT_MAX_SPLIT_SIZE_MB ) ;
117
- info ! ( "Setting PYTORCH_CUDA_ALLOC_CONF to default value: {alloc_conf}" ) ;
118
- Some ( alloc_conf)
115
+ info ! ( "Setting PYTORCH_CUDA_ALLOC_CONF to default value: {DEFAULT_PYTORCH_CUDA_ALLOC_CONF}" ) ;
116
+ Some ( DEFAULT_PYTORCH_CUDA_ALLOC_CONF )
119
117
} ,
120
118
Ok ( alloc_conf) if alloc_conf. trim ( ) . is_empty ( ) => {
121
119
info ! ( "PYTORCH_CUDA_ALLOC_CONF is unset" ) ;
122
- Some ( String :: new ( ) ) // This means remove it from the env
120
+ Some ( "" ) // This means remove it from the env
123
121
} ,
124
122
Ok ( alloc_conf) => {
125
123
info ! ( "PYTORCH_CUDA_ALLOC_CONF is set to: {alloc_conf}" ) ;
@@ -406,7 +404,7 @@ fn shard_manager(
406
404
max_batch_weight : Option < usize > ,
407
405
uds_path : String ,
408
406
cuda_process_memory_fraction : f32 ,
409
- cuda_alloc_conf : Option < String > ,
407
+ cuda_alloc_conf : Option < & str > ,
410
408
rank : usize ,
411
409
world_size : usize ,
412
410
master_addr : String ,
0 commit comments