@@ -7,7 +7,7 @@ use std::{
     io,
     io::{BufRead, BufReader, ErrorKind, Write},
     os::unix::process::{CommandExt, ExitStatusExt},
-    path::Path,
+    path::{Path, PathBuf},
    process::{Command, ExitCode, ExitStatus, Stdio},
    sync::{
        atomic::{AtomicBool, Ordering},
@@ -25,7 +25,8 @@ use nix::{
     sys::signal::{self, Signal},
     unistd::Pid,
 };
-use tracing::{info, warn};
+use tracing::{error, info, warn};
+use uuid::Uuid;
 
 // In most cases this gives the best performance for inferencing
 const DEFAULT_PYTORCH_CUDA_ALLOC_CONF: &str = "expandable_segments:True";
@@ -135,13 +136,28 @@ fn main() -> ExitCode {
     // Determine number of shards based on command line arg and env vars
     let num_shard = find_num_shards(args.num_shard);
 
-    // Resolve fast tokenizer path
-    let tokenizer_path = resolve_tokenizer_path(&args.model_name, args.revision.as_deref())
-        .expect("Could not find tokenizer for model");
+    let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
+        .expect("Failed to resolve config path")
+        .into();
+
+    // Save fast tokenizer to /tmp
+    let save_path = format!("/tmp/{}", Uuid::new_v4());
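+    // Prefer the freshly converted tokenizer; fall back to the cached tokenizer.json on failure.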
+    let tokenizer_path =
+        if save_fast_tokenizer(&args.model_name, args.revision.as_deref(), &save_path).is_ok() {
+            format!("{save_path}/tokenizer.json")
+        } else {
+            warn!("Failed to (re-)convert tokenizer, falling back to existing fast tokenizer");
+            let tokenizer_path = config_path.parent().unwrap().join("tokenizer.json");
+            if tokenizer_path.is_file() {
+                tokenizer_path.to_string_lossy().to_string()
+            } else {
+                panic!("No existing fast tokenizer (tokenizer.json) found")
+            }
+        };
 
     // Determine max sequence length based on command line arg and env vars
-    let config_json_path = tokenizer_path.replace("tokenizer.json", "config.json");
-    let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_json_path);
+    let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_path);
 
     match env::var("MAX_BATCH_WEIGHT") {
         Ok(max_batch_weight) if !max_batch_weight.trim().is_empty() => {
@@ -415,15 +431,15 @@ fn num_cuda_devices() -> Option<usize> {
 /// 1. MAX_SEQUENCE_LENGTH set by user
 /// 2. The sequence length specified in config.json
 /// 3. A default of 2048
-fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path: &String) -> usize {
+fn get_max_sequence_length(max_sequence_length: Option<usize>, config_path: &Path) -> usize {
     if let Some(max_sequence_length) = max_sequence_length {
         info!(
             "Using configured max_sequence_length: {}",
             max_sequence_length
         );
         return max_sequence_length;
     }
-    if let Ok(model_config) = get_config_json(config_json_path) {
+    if let Ok(model_config) = get_config_json(config_path) {
         if let Some(length) = get_max_sequence_length_from_config(&model_config) {
             info!(
                 "Loaded max_sequence_length from model config.json: {}",
@@ -440,7 +456,7 @@ fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path:
 }
 
 /// Opens the model's config.json file and reads into serde_json value
-fn get_config_json(config_path: &String) -> Result<serde_json::Value, std::io::Error> {
+fn get_config_json(config_path: &Path) -> Result<serde_json::Value, std::io::Error> {
     let reader = BufReader::new(File::open(config_path)?);
     Ok(serde_json::from_reader(reader)?)
 }
@@ -734,7 +750,20 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
     let _ = shutdown_receiver.recv();
 }
 
-fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
+fn write_termination_log(msg: &str) -> Result<(), io::Error> {
+    // Writes a message to the termination log.
+    // Creates the logfile if it doesn't exist.
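+    // /dev/termination-log is the default Kubernetes termination-message path.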
+    let mut f = File::options()
+        .write(true)
+        .create(true)
+        .truncate(true)
+        .open("/dev/termination-log")?;
+    writeln!(f, "{}", msg)?;
+    Ok(())
+}
+
+fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
     let cache = env::var("TRANSFORMERS_CACHE")
         .or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
         .ok();
@@ -753,20 +782,19 @@ fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<St
             true => fs::read_to_string(ref_path)?,
             false => revision.to_string(),
         };
-        let tok_path = dir.join("snapshots").join(&revision).join("tokenizer.json");
-        if tok_path.try_exists()? {
-            Ok(tok_path.to_string_lossy().into())
+        let config_path = dir.join("snapshots").join(&revision).join("config.json");
+        if config_path.try_exists()? {
+            Ok(config_path.to_string_lossy().into())
         } else {
-            Err(io::Error::new(
-                ErrorKind::NotFound,
-                format!(
-                    "Tokenizer file not found in local cache for model {model_name}, revision {revision}"
-                )
-            ))
+            let message = format!(
+                "Config file not found in local cache for model {model_name}, revision {revision}"
+            );
+            error!(message);
+            Err(io::Error::new(ErrorKind::NotFound, message))
         }
     } else {
         // Try treating model_name as explicit model path
-        let try_path = Path::new(&model_name).join("tokenizer.json");
+        let try_path = Path::new(&model_name).join("config.json");
         if try_path.try_exists()? {
             Ok(try_path.to_string_lossy().into())
         } else {
@@ -775,20 +803,50 @@ fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<St
             } else {
                 format!("Model {model_name} not found in local cache")
             };
-
+            error!(message);
             Err(io::Error::new(ErrorKind::NotFound, message))
         }
     }
 }
 
-fn write_termination_log(msg: &str) -> Result<(), io::Error> {
-    // Writes a message to the termination log.
-    // Creates the logfile if it doesn't exist.
-    let mut f = File::options()
-        .write(true)
-        .create(true)
-        .truncate(true)
-        .open("/dev/termination-log")?;
-    writeln!(f, "{}", msg)?;
-    Ok(())
+/// Convert and save fast tokenizer via transformers `AutoTokenizer.from_pretrained`.
+fn save_fast_tokenizer(
+    model_name: &str,
+    revision: Option<&str>,
+    save_path: &str,
+) -> Result<(), io::Error> {
+    info!("Saving fast tokenizer for `{model_name}` to `{save_path}`");
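+    // Escape quotes, backslashes, and control characters so the values can be
+    // embedded safely in the Python snippet below.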
+    let model_name = model_name.escape_default();
+    let revision = revision.map(|v| v.escape_default());
+    let code = if let Some(revision) = revision {
+        format!(
+            "from transformers import AutoTokenizer; \
+             AutoTokenizer.from_pretrained(\"{model_name}\", \
+             revision=\"{revision}\").save_pretrained(\"{save_path}\")"
+        )
+    } else {
+        format!(
+            "from transformers import AutoTokenizer; \
+             AutoTokenizer.from_pretrained(\"{model_name}\").save_pretrained(\"{save_path}\")"
+        )
+    };
+    match Command::new("python").args(["-c", &code]).status() {
+        Ok(status) => {
+            if status.success() {
+                Ok(())
+            } else {
+                let message = "Python process for tokenizer conversion failed".to_string();
+                error!(
+                    exit_code = ?status.code(),
+                    message
+                );
+                Err(io::Error::new(ErrorKind::Other, message))
+            }
+        }
+        Err(e) => {
+            error!("Failed to launch python process for tokenizer conversion");
+            Err(e)
+        }
+    }
 }