2222#endif
2323
// Top-level action performed by the tool, selected on the command line.
//
// OP_NONE is the "not specified" sentinel: it is the default value of
// split_params::operation, and the argument parser replaces it with OP_SPLIT
// once parsing finishes if neither --split nor --merge was given. Keeping a
// distinct sentinel lets the parser detect conflicting --split/--merge flags
// without a separate "is set" boolean.
enum split_operation : uint8_t {
    OP_NONE  = 0, // no operation chosen yet (parser default)
    OP_SPLIT = 1, // set by --split
    OP_MERGE = 2, // set by --merge
};
29+
// Criterion used to decide where one output split ends and the next begins.
//
// MODE_NONE is the "not specified" sentinel: it is the default value of
// split_params::mode, and the argument parser replaces it with MODE_TENSOR
// once parsing finishes if no --split-max-* option was given. The sentinel
// also lets the parser reject --split-max-tensors combined with --split-max-size.
enum split_mode : uint8_t {
    MODE_NONE   = 0, // no mode chosen yet (parser default)
    MODE_TENSOR = 1, // set by --split-max-tensors (uses n_split_tensors)
    MODE_SIZE   = 2, // set by --split-max-size (uses n_bytes_split)
};
2835
2936struct split_params {
30- split_operation operation = SPLIT_OP_SPLIT;
37+ split_operation operation = OP_NONE;
38+ split_mode mode = MODE_NONE;
3139 size_t n_bytes_split = 0 ;
3240 int n_split_tensors = 128 ;
3341 std::string input;
@@ -80,8 +88,6 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
8088 bool invalid_param = false ;
8189
8290 int arg_idx = 1 ;
83- bool is_op_set = false ;
84- bool is_mode_set = false ;
8591 for (; arg_idx < argc && strncmp (argv[arg_idx], " --" , 2 ) == 0 ; arg_idx++) {
8692 arg = argv[arg_idx];
8793 if (arg.compare (0 , arg_prefix.size (), arg_prefix) == 0 ) {
@@ -92,54 +98,49 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
9298 if (arg == " -h" || arg == " --help" ) {
9399 split_print_usage (argv[0 ]);
94100 exit (0 );
95- }
96- if (arg == " --version" ) {
101+ } else if (arg == " --version" ) {
97102 fprintf (stderr, " version: %d (%s)\n " , LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
98103 fprintf (stderr, " built with %s for %s\n " , LLAMA_COMPILER, LLAMA_BUILD_TARGET);
99104 exit (0 );
100- }
101- if (arg == " --dry-run" ) {
105+ } else if (arg == " --dry-run" ) {
102106 arg_found = true ;
103107 params.dry_run = true ;
104- }
105- if (arg == " --no-tensor-first-split" ) {
108+ } else if (arg == " --no-tensor-first-split" ) {
106109 arg_found = true ;
107110 params.no_tensor_first_split = true ;
108- }
109-
110- if (is_op_set) {
111- throw std::invalid_argument (" error: either --split or --merge can be specified, but not both" );
112- }
113- if (arg == " --merge" ) {
111+ } else if (arg == " --merge" ) {
114112 arg_found = true ;
115- is_op_set = true ;
116- params.operation = SPLIT_OP_MERGE;
117- }
118- if (arg == " --split" ) {
113+ if (params.operation != OP_NONE && params.operation != OP_MERGE) {
114+ throw std::invalid_argument (" error: either --split or --merge can be specified, but not both" );
115+ }
116+ params.operation = OP_MERGE;
117+ } else if (arg == " --split" ) {
119118 arg_found = true ;
120- is_op_set = true ;
121- params.operation = SPLIT_OP_SPLIT;
122- }
123-
124- if (is_mode_set) {
125- throw std::invalid_argument (" error: either --split-max-tensors or --split-max-size can be specified, but not both" );
126- }
127- if (arg == " --split-max-tensors" ) {
119+ if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
120+ throw std::invalid_argument (" error: either --split or --merge can be specified, but not both" );
121+ }
122+ params.operation = OP_SPLIT;
123+ } else if (arg == " --split-max-tensors" ) {
128124 if (++arg_idx >= argc) {
129125 invalid_param = true ;
130126 break ;
131127 }
132128 arg_found = true ;
133- is_mode_set = true ;
129+ if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
130+ throw std::invalid_argument (" error: either --split-max-tensors or --split-max-size can be specified, but not both" );
131+ }
132+ params.mode = MODE_TENSOR;
134133 params.n_split_tensors = atoi (argv[arg_idx]);
135- }
136- if (arg == " --split-max-size" ) {
134+ } else if (arg == " --split-max-size" ) {
137135 if (++arg_idx >= argc) {
138136 invalid_param = true ;
139137 break ;
140138 }
141139 arg_found = true ;
142- is_mode_set = true ;
140+ if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
141+ throw std::invalid_argument (" error: either --split-max-tensors or --split-max-size can be specified, but not both" );
142+ }
143+ params.mode = MODE_SIZE;
143144 params.n_bytes_split = split_str_to_n_bytes (argv[arg_idx]);
144145 }
145146
@@ -148,6 +149,15 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
148149 }
149150 }
150151
152+ // the operation is split if not specified
153+ if (params.operation == OP_NONE) {
154+ params.operation = OP_SPLIT;
155+ }
156+ // the split mode is by tensor if not specified
157+ if (params.mode == MODE_NONE) {
158+ params.mode = MODE_TENSOR;
159+ }
160+
151161 if (invalid_param) {
152162 throw std::invalid_argument (" error: invalid parameter for argument: " + arg);
153163 }
@@ -265,13 +275,15 @@ struct split_strategy {
265275 }
266276
267277 bool should_split (int i_tensor, size_t next_size) {
268- if (params.n_bytes_split > 0 ) {
278+ if (params.mode == MODE_SIZE ) {
269279 // split by max size per file
270280 return next_size > params.n_bytes_split ;
271- } else {
281+ } else if (params. mode == MODE_TENSOR) {
272282 // split by number of tensors per file
273283 return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0 ;
274284 }
285+ // should never happen
286+ return false ;
275287 }
276288
277289 void print_info () {
@@ -559,9 +571,9 @@ int main(int argc, const char ** argv) {
559571 split_params_parse (argc, argv, params);
560572
561573 switch (params.operation ) {
562- case SPLIT_OP_SPLIT : gguf_split (params);
574+ case OP_SPLIT : gguf_split (params);
563575 break ;
564- case SPLIT_OP_MERGE : gguf_merge (params);
576+ case OP_MERGE : gguf_merge (params);
565577 break ;
566578 default : split_print_usage (argv[0 ]);
567579 exit (EXIT_FAILURE);
0 commit comments