diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 0b7a756bf0d..812c39c2b64 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -62,6 +62,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/backends/...",
+        "//executorch/examples/...",
     ],
     deps = [
         ":int4_weight_only_quantizer",
diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 17597f3d502..6d806cbc0d6 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -103,6 +103,7 @@ runtime.python_library(
     deps = [
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 8cff6e8e11a..0a3c7620cb6 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -157,7 +157,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
         help="type of quantization",
     )
 
diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py
index 3879c00fd91..274fc447b36 100644
--- a/examples/models/llama2/source_transformation/quantize.py
+++ b/examples/models/llama2/source_transformation/quantize.py
@@ -12,6 +12,8 @@ import torch.nn as nn
 
 import torch.nn.functional as F
 
+from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
+
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
@@ -31,7 +33,7 @@
 fsLinear = nn.Linear
 
 
-def quantize(
+def quantize(  # noqa C901
     model: torch.nn.Module,
     qmode: str,
     activation_dtype: Optional[DType],
@@ -131,6 +133,9 @@ def quantize(
         )
         model = gptq_quantizer.quantize(model, inputs)
         return model
+    elif qmode == "vulkan_4w":
+        model = VkInt4WeightOnlyQuantizer().quantize(model)
+        return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
 
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 87feb080b4a..d966de9a251 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -37,9 +37,6 @@ def get_vulkan_partitioner(
     assert (
         dtype_override == "fp32" or dtype_override is None
     ), "Vulkan backend does not support non fp32 dtypes at the moment"
-    assert (
-        quantization_mode is None
-    ), "Vulkan backend does not support quantization at the moment"
     from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
         VulkanPartitioner,
     )
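
For reference, below is a minimal sketch (not part of this patch) of what the new "vulkan_4w" mode does: it applies the same VkInt4WeightOnlyQuantizer().quantize(model) call that the patched quantize() now makes. The toy module and its dimensions are illustrative assumptions, not code from the patch.

import torch
import torch.nn as nn

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer


class TinyModel(nn.Module):
    # Toy model for illustration; the 4-bit weight-only quantizer rewrites
    # the weights of nn.Linear layers like this one.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(256, 256, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel().eval()
# Same call the new "vulkan_4w" branch in quantize() performs.
model = VkInt4WeightOnlyQuantizer().quantize(model)

In the llama2 export flow, this path is selected by passing --quantization_mode vulkan_4w, which is why the assertion rejecting quantization in get_vulkan_partitioner is removed.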