|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import glob |
| 16 | +import itertools |
16 | 17 | import os |
17 | 18 | import subprocess |
18 | 19 | import warnings |
@@ -78,42 +79,40 @@ def qnemo_to_tensorrt_llm( |
78 | 79 |
|
79 | 80 | speculative_decoding_mode = "medusa" if "Medusa" in config.architecture else None |
80 | 81 |
|
81 | | - build_cmd = "trtllm-build " |
82 | | - build_cmd += f"--checkpoint_dir {nemo_checkpoint_path} " |
83 | | - build_cmd += f"--log_level {log_level} " |
84 | | - build_cmd += f"--output_dir {engine_dir} " |
85 | | - build_cmd += f"--workers {num_build_workers} " |
86 | | - build_cmd += f"--max_batch_size {max_batch_size} " |
87 | | - build_cmd += f"--max_input_len {max_input_len} " |
88 | | - build_cmd += f"--max_beam_width {max_beam_width} " |
89 | | - build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} " |
90 | | - build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} " |
91 | | - build_cmd += f"--use_paged_context_fmha {'enable' if paged_context_fmha else 'disable'} " |
92 | | - build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " |
93 | | - build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} " |
94 | | - build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} " |
95 | | - build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} " |
| 82 | + build_cmd = ["trtllm-build"] |
| 83 | + build_cmd.extend(["--checkpoint_dir", nemo_checkpoint_path]) |
| 84 | + build_cmd.extend(["--log_level", log_level]) |
| 85 | + build_cmd.extend(["--output_dir", engine_dir]) |
| 86 | + build_cmd.extend(["--workers", str(num_build_workers)]) |
| 87 | + build_cmd.extend(["--max_batch_size", str(max_batch_size)]) |
| 88 | + build_cmd.extend(["--max_input_len", str(max_input_len)]) |
| 89 | + build_cmd.extend(["--max_beam_width", str(max_beam_width)]) |
| 90 | + build_cmd.extend(["--max_prompt_embedding_table_size", str(max_prompt_embedding_table_size)]) |
| 91 | + build_cmd.extend(["--paged_kv_cache", "enable" if paged_kv_cache else "disable"]) |
| 92 | + build_cmd.extend(["--use_paged_context_fmha", "enable" if paged_context_fmha else "disable"]) |
| 93 | + build_cmd.extend(["--remove_input_padding", "enable" if remove_input_padding else "disable"]) |
| 94 | + build_cmd.extend(["--multiple_profiles", "enable" if multiple_profiles else "disable"]) |
| 95 | + build_cmd.extend(["--reduce_fusion", "enable" if reduce_fusion else "disable"]) |
| 96 | + build_cmd.extend(["--use_fused_mlp", "enable" if use_fused_mlp else "disable"]) |
96 | 97 |
|
97 | 98 | if not use_qdq: |
98 | | - build_cmd += "--gemm_plugin auto " |
| 99 | + build_cmd.extend(["--gemm_plugin", "auto"]) |
99 | 100 |
|
100 | 101 | if max_seq_len is not None: |
101 | | - build_cmd += f"--max_seq_len {max_seq_len} " |
| 102 | + build_cmd.extend(["--max_seq_len", str(max_seq_len)]) |
102 | 103 |
|
103 | 104 | if max_num_tokens is not None: |
104 | | - build_cmd += f"--max_num_tokens {max_num_tokens} " |
| 105 | + build_cmd.extend(["--max_num_tokens", str(max_num_tokens)]) |
105 | 106 | else: |
106 | | - build_cmd += f"--max_num_tokens {max_batch_size * max_input_len} " |
| 107 | + build_cmd.extend(["--max_num_tokens", str(max_batch_size * max_input_len)]) |
107 | 108 |
|
108 | 109 | if opt_num_tokens is not None: |
109 | | - build_cmd += f"--opt_num_tokens {opt_num_tokens} " |
| 110 | + build_cmd.extend(["--opt_num_tokens", str(opt_num_tokens)]) |
110 | 111 |
|
111 | 112 | if speculative_decoding_mode: |
112 | | - build_cmd += f"--speculative_decoding_mode {speculative_decoding_mode} " |
113 | | - |
114 | | - build_cmd = build_cmd.replace("--", "\\\n --") # Separate parameters line by line |
| 113 | + build_cmd.extend(["--speculative_decoding_mode", speculative_decoding_mode]) |
115 | 114 |
|
116 | 115 | print("trtllm-build command:") |
117 | | - print(build_cmd) |
| 116 | + print("".join(itertools.chain.from_iterable(zip(build_cmd, itertools.cycle(["\n ", " "])))).strip()) |
118 | 117 |
|
119 | | - subprocess.run(build_cmd, shell=True, check=True) |
| 118 | + subprocess.run(build_cmd, shell=False, check=True) |
0 commit comments