diff --git a/.azure-pipelines/scripts/ut/run_ut_xpu.sh b/.azure-pipelines/scripts/ut/run_ut_xpu.sh
index eb2caa5f9..405c72b1c 100644
--- a/.azure-pipelines/scripts/ut/run_ut_xpu.sh
+++ b/.azure-pipelines/scripts/ut/run_ut_xpu.sh
@@ -24,11 +24,24 @@ ut_log_name=${LOG_DIR}/ut.log
 find ./test_ark -name "test*.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_ark.sh
 cat run_ark.sh
 find ./test_xpu -name "test*.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_xpu.sh
+sed -i "/test_llmc_integration.py/d" run_xpu.sh
 cat run_xpu.sh
+find ./test_xpu -name "test_llmc_integration.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_xpu_llmc.sh
+cat run_xpu_llmc.sh
+echo "##[group]Run xpu test on xpu..."
 numactl --physcpubind="${NUMA_CPUSET:-0-27}" --membind="${NUMA_NODE:-0}" bash run_xpu.sh 2>&1 | tee "${ut_log_name}"
+echo "##[endgroup]"
+echo "##[group]Run Ark test on xpu..."
 numactl --physcpubind="${NUMA_CPUSET:-0-27}" --membind="${NUMA_NODE:-0}" bash run_ark.sh 2>&1 | tee -a "${ut_log_name}"
+echo "##[endgroup]"
+
+echo "##[group]Run LLMC integration test on xpu..."
+uv pip install -r ./test_xpu/requirements_llmc.txt
+uv pip list
+numactl --physcpubind="${NUMA_CPUSET:-0-27}" --membind="${NUMA_NODE:-0}" bash run_xpu_llmc.sh 2>&1 | tee -a "${ut_log_name}"
+echo "##[endgroup]"
 
 cp report.html ${LOG_DIR}/
 cp coverage.xml ${LOG_DIR}/
diff --git a/test/test_cuda/integrations/test_llmc_integration.py b/test/test_cuda/integrations/test_llmc_integration.py
deleted file mode 120000
index 3422e3cdc..000000000
--- a/test/test_cuda/integrations/test_llmc_integration.py
+++ /dev/null
@@ -1 +0,0 @@
-../../test_cpu/integrations/test_llmc_integration.py
\ No newline at end of file
diff --git a/test/test_cuda/integrations/test_llmc_integration.py b/test/test_cuda/integrations/test_llmc_integration.py
new file mode 100644
index 000000000..c98faf7bf
--- /dev/null
+++ b/test/test_cuda/integrations/test_llmc_integration.py
@@ -0,0 +1,239 @@
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round.calib_dataset import get_dataset
+
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        AutoRoundModifier:
+            ignore: ["lm_head"]
+            iters: 10
+            config_groups:
+                group_0:
+                    targets:
+                        - "Linear"
+                    input_activations: null
+                    output_activations: null
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: group
+                        group_size: 128
+"""
+
+recipe_modifier_full = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=10,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+        )
+    },
+)
+recipe_modifier_nvfp4 = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=2,
+    scheme="NVFP4",
+)
+
+recipe_modifier_mxfp4 = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    scheme="MXFP4",
+)
+
+w8a8_dynamic_recipe_modifier = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=8, type="float", strategy="channel"),
+            input_activations=QuantizationArgs(num_bits=8, type="float", strategy="token", dynamic=True),
+        )
+    },
+)
+
+w8a8_static_recipe_modifier = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+            input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+        )
+    },
+)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires at least 1 Cuda GPU")
+@pytest.mark.parametrize(
+    "recipe",
+    [
+        recipe_str,
+        recipe_modifier_full,
+        recipe_modifier_nvfp4,
+        recipe_modifier_mxfp4,
+    ],
+)
+def test_oneshot_application(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    # Check that the model is quantized
+    # decompress() will attach a quantization_config to the model
+    # as we decompress right away
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires at least 2 Cuda GPUs")
+def test_oneshot_with_device_ids(tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=512,
+        nsamples=4,
+    )
+
+    device = "cuda:0"
+
+    recipe = AutoRoundModifier(
+        ignore=["lm_head"],
+        iters=10,
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+            )
+        },
+        device_ids="0,1",
+    )
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    # Check that the model is quantized
+    # decompress() will attach a quantization_config to the model
+    # as we decompress right away
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires at least 1 Cuda GPU")
+@pytest.mark.parametrize(
+    "recipe",
+    [w8a8_dynamic_recipe_modifier, w8a8_static_recipe_modifier],
+)
+def test_rtn_oneshot(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    act_args = quantization_config.config_groups["group_0"].input_activations
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == recipe.config_groups["group_0"].weights.num_bits
+    assert weight_args.strategy == recipe.config_groups["group_0"].weights.strategy
+    if act_args is not None:
+        assert act_args.num_bits == recipe.config_groups["group_0"].input_activations.num_bits
+        assert act_args.strategy == recipe.config_groups["group_0"].input_activations.strategy
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
diff --git a/test/test_xpu/requirements_llmc.txt b/test/test_xpu/requirements_llmc.txt
new file mode 100644
index 000000000..0af08f61e
--- /dev/null
+++ b/test/test_xpu/requirements_llmc.txt
@@ -0,0 +1 @@
+llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@main
diff --git a/test/test_xpu/test_llmc_integration.py b/test/test_xpu/test_llmc_integration.py
new file mode 100644
index 000000000..2b0310306
--- /dev/null
+++ b/test/test_xpu/test_llmc_integration.py
@@ -0,0 +1,233 @@
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round.calib_dataset import get_dataset
+
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        AutoRoundModifier:
+            ignore: ["lm_head"]
+            iters: 10
+            config_groups:
+                group_0:
+                    targets:
+                        - "Linear"
+                    input_activations: null
+                    output_activations: null
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: group
+                        group_size: 128
+"""
+
+recipe_modifier_full = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=10,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+        )
+    },
+)
+recipe_modifier_nvfp4 = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=2,
+    scheme="NVFP4",
+)
+
+recipe_modifier_mxfp4 = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    scheme="MXFP4",
+)
+
+w8a8_dynamic_recipe_modifier = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=8, type="float", strategy="channel"),
+            input_activations=QuantizationArgs(num_bits=8, type="float", strategy="token", dynamic=True),
+        )
+    },
+)
+
+w8a8_static_recipe_modifier = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=0,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+            input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+        )
+    },
+)
+
+
+@pytest.mark.skipif(torch.xpu.device_count() < 1, reason="test requires at least 1 XPU")
+@pytest.mark.parametrize(
+    "recipe",
+    [
+        recipe_str,
+        recipe_modifier_full,
+        recipe_modifier_nvfp4,
+        recipe_modifier_mxfp4,
+    ],
+)
+def test_oneshot_application(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map="xpu")
+
+    # Check that the model is quantized
+    # decompress() will attach a quantization_config to the model
+    # as we decompress right away
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
+
+
+@pytest.mark.skipif(torch.xpu.device_count() < 2, reason="test requires at least 2 XPUs")
+def test_oneshot_with_device_ids(tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=512,
+        nsamples=4,
+    )
+
+    recipe = AutoRoundModifier(
+        ignore=["lm_head"],
+        iters=10,
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+            )
+        },
+        device_ids="0,1",
+    )
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map="xpu")
+
+    # Check that the model is quantized
+    # decompress() will attach a quantization_config to the model
+    # as we decompress right away
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
+
+
+@pytest.mark.skipif(torch.xpu.device_count() < 1, reason="test requires at least 1 XPU")
+@pytest.mark.parametrize(
+    "recipe",
+    [w8a8_dynamic_recipe_modifier, w8a8_static_recipe_modifier],
+)
+def test_rtn_oneshot(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map="xpu")
+
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    act_args = quantization_config.config_groups["group_0"].input_activations
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == recipe.config_groups["group_0"].weights.num_bits
+    assert weight_args.strategy == recipe.config_groups["group_0"].weights.strategy
+    if act_args is not None:
+        assert act_args.num_bits == recipe.config_groups["group_0"].input_activations.num_bits
+        assert act_args.strategy == recipe.config_groups["group_0"].input_activations.strategy
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")