@@ -256,6 +256,7 @@ struct cmd_params {
256256 std::vector<bool > embeddings;
257257 std::vector<llama_model_tensor_buft_override> buft_overrides;
258258 ggml_numa_strategy numa;
259+ std::string cuda_params;
259260 int reps;
260261 bool verbose;
261262 bool warmup;
@@ -295,6 +296,7 @@ static const cmd_params cmd_params_defaults = {
295296 /* embeddings */ {false },
296297 /* buft_overrides */ {},
297298 /* numa */ GGML_NUMA_STRATEGY_DISABLED,
299+ /* cuda_params */ {},
298300 /* reps */ 5 ,
299301 /* verbose */ false ,
300302 /* warmup */ true ,
@@ -344,6 +346,7 @@ static void print_usage(int /* argc */, char ** argv) {
344346 printf (" -v, --verbose (default: %s)\n " , cmd_params_defaults.verbose ? " 1" : " 0" );
345347 printf (" -w, --warmup <0|1> (default: %s)\n " , cmd_params_defaults.warmup ? " 1" : " 0" );
346348 printf (" -rtr, --run-time-repack <0|1> (default: %s)\n " , cmd_params_defaults.repack ? " 1" : " 0" );
349+ printf (" -cuda, --cuda-params <string> (default: %s)\n " , cmd_params_defaults.cuda_params.empty () ? "none" : cmd_params_defaults.cuda_params.c_str ()); // show the option's own string default; an empty default is reported as "none"
347350 printf (" -mqkv, --merge-qkv (default: %s)\n " , cmd_params_defaults.mqkv ? " 1" : " 0" );
348351 printf (" -thp, --transparent-huge-pages <0|1> (default: %s)\n " , cmd_params_defaults.use_thp ? " 1" : " 0" );
349352 printf (" -ot, --override-tensor pattern (default: none)\n " );
@@ -736,6 +739,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
736739 break ;
737740 }
738741 params.repack = std::stoi (argv[i]);
742+ } else if (arg == " -cuda" || arg == " --cuda-params" ) {
743+ if (++i >= argc) {
744+ invalid_param = true ;
745+ break ;
746+ }
747+ params.cuda_params = argv[i];
739748 } else if (arg == " -mqkv" || arg == " --merge-qkv" ) {
740749 if (++i >= argc) {
741750 invalid_param = true ;
@@ -852,6 +861,7 @@ struct cmd_params_instance {
852861 int attn_max_batch;
853862 Ser ser;
854863 std::vector<float > tensor_split;
864+ std::string cuda_params;
855865 bool use_mmap;
856866 bool embeddings;
857867 bool repack = false ;
@@ -914,6 +924,7 @@ struct cmd_params_instance {
914924 cparams.min_experts = ser.first ;
915925 cparams.thresh_experts = ser.second ;
916926 cparams.embeddings = embeddings;
927+ cparams.cuda_params = (void *)cuda_params.data ();
917928
918929 return cparams;
919930 }
@@ -965,6 +976,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
965976 /* .attn_max_b = */ amb,
966977 /* .ser = */ ser,
967978 /* .tensor_split = */ ts,
979+ /* .cuda_params = */ params.cuda_params ,
968980 /* .use_mmap = */ mmp,
969981 /* .embeddings = */ embd,
970982 /* .repack = */ params.repack ,
@@ -1003,6 +1015,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10031015 /* .attn_max_b = */ amb,
10041016 /* .ser = */ ser,
10051017 /* .tensor_split = */ ts,
1018+ /* .cuda_params = */ params.cuda_params ,
10061019 /* .use_mmap = */ mmp,
10071020 /* .embeddings = */ embd,
10081021 /* .repack = */ params.repack ,
@@ -1041,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10411054 /* .attn_max_b = */ amb,
10421055 /* .ser = */ ser,
10431056 /* .tensor_split = */ ts,
1057+ /* .cuda_params = */ params.cuda_params ,
10441058 /* .use_mmap = */ mmp,
10451059 /* .embeddings = */ embd,
10461060 /* .repack = */ params.repack ,
@@ -1079,6 +1093,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10791093 /* .attn_max_b = */ amb,
10801094 /* .ser = */ ser,
10811095 /* .tensor_split = */ ts,
1096+ /* .cuda_params = */ params.cuda_params ,
10821097 /* .use_mmap = */ mmp,
10831098 /* .embeddings = */ embd,
10841099 /* .repack = */ params.repack ,
@@ -1128,6 +1143,7 @@ struct test {
11281143 int attn_max_batch;
11291144 Ser ser;
11301145 std::vector<float > tensor_split;
1146+ std::string cuda_params;
11311147 bool use_mmap;
11321148 bool embeddings;
11331149 bool repack = false ;
@@ -1166,6 +1182,7 @@ struct test {
11661182 attn_max_batch = inst.attn_max_batch ;
11671183 ser = inst.ser ;
11681184 tensor_split = inst.tensor_split ;
1185+ cuda_params = inst.cuda_params ;
11691186 use_mmap = inst.use_mmap ;
11701187 embeddings = inst.embeddings ;
11711188 repack = inst.repack ;
0 commit comments