@@ -48,32 +48,31 @@ def temp_lora_dir(self) -> str:
4848 task_dir .mkdir (parents = True , exist_ok = True )
4949 yield str (lora_dir )
5050
51- def _build_base_command (self , llm_root : Path ) -> List [str ]:
51+ def _build_base_command (self , output_path : Path ) -> List [str ]:
5252 """
5353 Build the base command for running prepare_dataset.py.
5454
5555 Args:
56- llm_root : Path to the TensorRT LLM root directory
56+ output_path : Path to the output dataset file
5757
5858 Returns:
5959 List[str]: Base command components
6060
6161 Raises:
6262 pytest.skip: If LLM_MODELS_ROOT is not available
6363 """
64- script_path = llm_root / _PREPARE_DATASET_SCRIPT_PATH
65- cmd = ["python3" , str (script_path )]
64+ cmd = ["trtllm-bench" ]
6665
6766 # Add required tokenizer argument
6867 model_cache = llm_models_root ()
6968 if model_cache is None :
7069 pytest .skip ("LLM_MODELS_ROOT not available" )
7170
7271 tokenizer_dir = model_cache / _TOKENIZER_SUBPATH
73- cmd .extend (["--tokenizer " , str (tokenizer_dir )])
72+ cmd .extend (["--model " , str (tokenizer_dir )])
7473
7574 # Always add --stdout flag since we parse stdout output
76- cmd .extend (["--stdout " ])
75+ cmd .extend (["dataset" , "--output" , f" { output_path } " ])
7776
7877 return cmd
7978
@@ -109,7 +108,7 @@ def _add_synthetic_data_arguments(self, cmd: List[str]) -> None:
109108 str (_DEFAULT_OUTPUT_STDEV )
110109 ])
111110
112- def _run_prepare_dataset (self , llm_root : Path , ** kwargs ) -> str :
111+ def _run_prepare_dataset (self , ** kwargs ) -> str :
113112 """
114113 Execute prepare_dataset.py with specified parameters and capture
115114 output.
@@ -124,13 +123,20 @@ def _run_prepare_dataset(self, llm_root: Path, **kwargs) -> str:
124123 Raises:
125124 subprocess.CalledProcessError: If the command execution fails
126125 """
127- cmd = self ._build_base_command (llm_root )
128- self ._add_lora_arguments (cmd , ** kwargs )
129- self ._add_synthetic_data_arguments (cmd )
126+ with tempfile .TemporaryDirectory () as temp_dir :
127+ output_path = Path (temp_dir ) / "dataset.jsonl"
128+ cmd = self ._build_base_command (output_path )
129+ self ._add_lora_arguments (cmd , ** kwargs )
130+ self ._add_synthetic_data_arguments (cmd )
131+
132+ # Execute command and capture output
133+ subprocess .run (cmd , check = True , cwd = temp_dir )
134+
135+ data = ""
136+ with open (output_path , "r" ) as f :
137+ data = f .read ()
130138
131- # Execute command and capture output
132- result = subprocess .run (cmd , capture_output = True , text = True , check = True )
133- return result .stdout
139+ return data
134140
135141 def _parse_json_output (self , output : str ) -> List [Dict [str , Any ]]:
136142 """
@@ -198,7 +204,7 @@ def _validate_lora_request(self,
198204 },
199205 id = "random_task_id" )
200206 ])
201- def test_lora_metadata_generation (self , llm_root : Path , temp_lora_dir : str ,
207+ def test_lora_metadata_generation (self , temp_lora_dir : str ,
202208 test_params : Dict ) -> None :
203209 """Test LoRA metadata generation with various configurations."""
204210 # Extract test parameters
@@ -213,7 +219,7 @@ def test_lora_metadata_generation(self, llm_root: Path, temp_lora_dir: str,
213219 if rand_task_id is not None :
214220 kwargs ["rand_task_id" ] = rand_task_id
215221
216- output = self ._run_prepare_dataset (llm_root , ** kwargs )
222+ output = self ._run_prepare_dataset (** kwargs )
217223 json_data = self ._parse_json_output (output )
218224
219225 assert len (json_data ) > 0 , f"No JSON data generated for { description } "