
Commit 619ddbb

Refactor test utils (#476)
Signed-off-by: Keval Morabia <[email protected]>
1 parent fc22673 commit 619ddbb

File tree: 129 files changed, +675 −671 lines

Note: this is a large commit, so most of the 129 changed files are hidden by default; only the file entries below are shown.

tests/_test_utils/ptq_utils.py renamed to tests/_test_utils/examples/llm_ptq_utils.py
Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 import pytest
 import torch
 
-PTQ_EXAMPLE_DIR = Path(__file__).parents[2] / "examples" / "llm_ptq"
+PTQ_EXAMPLE_DIR = Path(__file__).parents[3] / "examples" / "llm_ptq"
 
 
 @dataclass
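
The rename moves the module one directory deeper (tests/_test_utils/ to tests/_test_utils/examples/), so the parents index that reaches the repository root increases from 2 to 3. A minimal sketch of pathlib's parents indexing, using an illustrative path:

    from pathlib import Path

    # Illustrative location mirroring the renamed file:
    # <repo>/tests/_test_utils/examples/llm_ptq_utils.py
    f = Path("/repo/tests/_test_utils/examples/llm_ptq_utils.py")

    assert f.parents[0] == Path("/repo/tests/_test_utils/examples")
    assert f.parents[1] == Path("/repo/tests/_test_utils")
    assert f.parents[2] == Path("/repo/tests")
    assert f.parents[3] == Path("/repo")  # repo root, hence parents[3]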
File renamed without changes.

tests/_test_utils/examples/run_command.py
Lines changed: 18 additions & 100 deletions

@@ -12,18 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Utility functions for running example commands reused in multiple example tests."""
 
 import os
 import subprocess
-import time
 from pathlib import Path
 
-from _test_utils.torch_dist.dist_utils import get_free_port
+from _test_utils.torch.distributed.utils import get_free_port
 
-MODELOPT_ROOT = Path(__file__).parent.parent.parent.parent
+MODELOPT_ROOT = Path(__file__).parents[3]
 
 
-def _extend_cmd_parts(cmd_parts: list[str], **kwargs):
+def extend_cmd_parts(cmd_parts: list[str], **kwargs):
     for key, value in kwargs.items():
         if value is not None:
             cmd_parts.extend([f"--{key}", str(value)])
@@ -32,9 +32,14 @@ def _extend_cmd_parts(cmd_parts: list[str], **kwargs):
     return cmd_parts
 
 
-def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
+def run_example_command(
+    cmd_parts: list[str],
+    example_path: str,
+    setup_free_port: bool = False,
+    env: dict[str, str] | None = None,
+):
     print(f"[{example_path}] Running command: {cmd_parts}")
-    env = os.environ.copy()
+    env = env or os.environ.copy()
 
     if setup_free_port:
         free_port = get_free_port()
@@ -43,7 +48,9 @@ def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
 
 
-def run_command_in_background(cmd_parts, example_path, stdout=None, stderr=None, text=True):
+def run_command_in_background(
+    cmd_parts: list[str], example_path: str, stdout=None, stderr=None, text=True
+):
     print(f"Running command in background: {' '.join(str(part) for part in cmd_parts)}")
     process = subprocess.Popen(
         cmd_parts,
@@ -55,57 +62,7 @@ def run_command_in_background(cmd_parts, example_path, stdout=None, stderr=None,
     return process
 
 
-def run_llm_autodeploy_command(
-    model: str, quant: str, effective_bits: float, output_dir: str, **kwargs
-):
-    # Create temporary directory for saving the quantized checkpoint
-    port = get_free_port()
-    quantized_ckpt_dir = os.path.join(output_dir, "quantized_model")
-    kwargs.update(
-        {
-            "hf_ckpt": model,
-            "quant": quant,
-            "effective_bits": effective_bits,
-            "save_quantized_ckpt": quantized_ckpt_dir,
-            "port": port,
-        }
-    )
-
-    server_handler = None
-    try:
-        # Quantize and deploy the model to the background
-        cmd_parts = _extend_cmd_parts(["scripts/run_auto_quant_and_deploy.sh"], **kwargs)
-        # Pass None to stdout and stderr to see the output in the console
-        server_handler = run_command_in_background(
-            cmd_parts, "llm_autodeploy", stdout=None, stderr=None
-        )
-
-        # Wait for the server to start. We might need to build
-        time.sleep(100)
-
-        # Test the deployment
-        run_example_command(
-            ["python", "api_client.py", "--prompt", "What is AI?", "--port", str(port)],
-            "llm_autodeploy",
-        )
-    finally:
-        if server_handler:
-            server_handler.terminate()
-
-
-def run_torch_onnx_command(*, quantize_mode: str, onnx_save_path: str, calib_size: str, **kwargs):
-    kwargs.update(
-        {
-            "quantize_mode": quantize_mode,
-            "onnx_save_path": onnx_save_path,
-            "calibration_data_size": calib_size,
-        }
-    )
-    cmd_parts = _extend_cmd_parts(["python", "torch_quant_to_onnx.py"], **kwargs)
-    run_example_command(cmd_parts, "onnx_ptq")
-
-
-def run_llm_export_command(
+def run_onnx_llm_export_command(
     *, torch_dir: str, dtype: str, lm_head: str, output_dir: str, calib_size: str, **kwargs
 ):
     kwargs.update(
@@ -117,7 +74,7 @@ def run_llm_export_command(
             "calib_size": calib_size,
         }
     )
-    cmd_parts = _extend_cmd_parts(["python", "llm_export.py"], **kwargs)
+    cmd_parts = extend_cmd_parts(["python", "llm_export.py"], **kwargs)
     run_example_command(cmd_parts, "onnx_ptq")
 
 
@@ -126,7 +83,7 @@ def run_llm_ptq_command(*, model: str, quant: str, **kwargs):
     kwargs.setdefault("tasks", "quant")
     kwargs.setdefault("calib", 16)
 
-    cmd_parts = _extend_cmd_parts(["scripts/huggingface_example.sh", "--no-verbose"], **kwargs)
+    cmd_parts = extend_cmd_parts(["scripts/huggingface_example.sh", "--no-verbose"], **kwargs)
     run_example_command(cmd_parts, "llm_ptq")
 
 
@@ -135,44 +92,5 @@ def run_vlm_ptq_command(*, model: str, quant: str, **kwargs):
     kwargs.setdefault("tasks", "quant")
    kwargs.setdefault("calib", 16)
 
-    cmd_parts = _extend_cmd_parts(["scripts/huggingface_example.sh"], **kwargs)
+    cmd_parts = extend_cmd_parts(["scripts/huggingface_example.sh"], **kwargs)
     run_example_command(cmd_parts, "vlm_ptq")
-
-
-def run_diffusers_cmd(cmd_parts: list[str]):
-    run_example_command(cmd_parts, "diffusers/quantization")
-
-
-def run_llm_sparsity_command(
-    *, model: str, output_dir: str, sparsity_fmt: str = "sparsegpt", **kwargs
-):
-    kwargs.update(
-        {"model_name_or_path": model, "sparsity_fmt": sparsity_fmt, "output_dir": output_dir}
-    )
-    kwargs.setdefault("calib_size", 16)
-    kwargs.setdefault("device", "cuda")
-    kwargs.setdefault("dtype", "fp16")
-    kwargs.setdefault("model_max_length", 1024)
-
-    cmd_parts = _extend_cmd_parts(["python", "hf_pts.py"], **kwargs)
-    run_example_command(cmd_parts, "llm_sparsity")
-
-
-def run_llm_sparsity_ft_command(
-    *, model: str, restore_path: str, output_dir: str, data_path: str, **kwargs
-):
-    kwargs.update(
-        {
-            "model": model,
-            "restore_path": restore_path,
-            "output_dir": output_dir,
-            "data_path": data_path,
-        }
-    )
-    kwargs.setdefault("num_epochs", 0.01)
-    kwargs.setdefault("max_length", 128)
-    kwargs.setdefault("train_bs", 1)
-    kwargs.setdefault("eval_bs", 1)
-
-    cmd_parts = _extend_cmd_parts(["bash", "launch_finetune.sh"], **kwargs)
-    run_example_command(cmd_parts, "llm_sparsity")
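
For orientation, a minimal sketch of how the two core helpers compose after this refactor; the script name and flag values below are illustrative, and part of extend_cmd_parts's body is elided in the hunk above:

    from _test_utils.examples.run_command import extend_cmd_parts, run_example_command

    # extend_cmd_parts appends keyword arguments as "--key value" flags,
    # skipping any keyword whose value is None (per the visible hunk).
    cmd = extend_cmd_parts(["python", "llm_export.py"], dtype="fp16", calib_size=16, lm_head=None)
    # cmd is now ["python", "llm_export.py", "--dtype", "fp16", "--calib_size", "16"]

    # run_example_command executes the command from examples/<example_path>;
    # setup_free_port picks a free port via get_free_port, and the env
    # parameter (new in this commit) lets callers pass a custom environment.
    run_example_command(cmd, "onnx_ptq", setup_free_port=True)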

tests/_test_utils/onnx_quantization/utils.py renamed to tests/_test_utils/onnx/quantization/utils.py
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 import onnx_graphsurgeon as gs
 
 
-def _assert_nodes_are_quantized(nodes):
+def assert_nodes_are_quantized(nodes):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable):
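
Dropping the leading underscore makes the assertion helper part of the shared test-utility surface rather than a module-private function. A usage sketch; the model path and the MatMul filter are illustrative, not from the commit:

    import onnx
    import onnx_graphsurgeon as gs

    from _test_utils.onnx.quantization.utils import assert_nodes_are_quantized

    # Load a quantized ONNX model produced earlier in a test (path is illustrative).
    graph = gs.import_onnx(onnx.load("model.quant.onnx"))

    # Assert that every MatMul node consumes quantized inputs.
    assert_nodes_are_quantized([n for n in graph.nodes if n.op == "MatMul"])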

tests/_test_utils/torch_deploy/device_model.py renamed to tests/_test_utils/torch/deploy/device_model.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 
 import pytest
 import torch
-from _test_utils.torch_model.deploy_models import BaseDeployModel
+from _test_utils.torch.deploy.lib_test_models import BaseDeployModel
 
 from modelopt.torch._deploy import compile
 from modelopt.torch._deploy.utils.torch_onnx import _to_expected_onnx_type
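
Taken together, the renames visible in this commit move flat _test_utils modules into nested packages. A summary of the path changes shown above (the remaining hidden files presumably follow the same pattern):

    # Old flat module                        -> new nested module
    # _test_utils.ptq_utils                  -> _test_utils.examples.llm_ptq_utils
    # _test_utils.torch_dist.dist_utils      -> _test_utils.torch.distributed.utils
    # _test_utils.onnx_quantization.utils    -> _test_utils.onnx.quantization.utils
    # _test_utils.torch_model.deploy_models  -> _test_utils.torch.deploy.lib_test_models
    # _test_utils.torch_deploy.device_model  -> _test_utils.torch.deploy.device_model

    # So a test that previously imported:
    #     from _test_utils.torch_dist.dist_utils import get_free_port
    # now imports:
    from _test_utils.torch.distributed.utils import get_free_port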
