1 change: 1 addition & 0 deletions backends/vulkan/_passes/TARGETS
@@ -62,6 +62,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/backends/...",
+        "//executorch/examples/...",
     ],
     deps = [
         ":int4_weight_only_quantizer",
1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -103,6 +103,7 @@ runtime.python_library(
     deps = [
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
2 changes: 1 addition & 1 deletion examples/models/llama2/export_llama_lib.py
@@ -157,7 +157,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
         help="type of quantization",
     )
 
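For reference, a minimal sketch of exercising the new choice through the parser. The import path mirrors this file's location in the repo, and the other flags are assumed to be optional with defaults:

    # Sketch only: assumes every other flag in build_args_parser() has a default.
    from executorch.examples.models.llama2.export_llama_lib import build_args_parser

    parser = build_args_parser()
    args = parser.parse_args(["--quantization_mode", "vulkan_4w"])
    assert args.quantization_mode == "vulkan_4w"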
7 changes: 6 additions & 1 deletion examples/models/llama2/source_transformation/quantize.py
@@ -12,6 +12,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
+
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
@@ -31,7 +33,7 @@
 fsLinear = nn.Linear
 
 
-def quantize(
+def quantize(  # noqa C901
     model: torch.nn.Module,
     qmode: str,
     activation_dtype: Optional[DType],
@@ -131,6 +133,9 @@ def quantize(
         )
         model = gptq_quantizer.quantize(model, inputs)
         return model
+    elif qmode == "vulkan_4w":
+        model = VkInt4WeightOnlyQuantizer().quantize(model)
+        return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
 
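The new branch delegates entirely to the quantizer exported from backends/vulkan/_passes. A minimal sketch of that call in isolation; the toy module and its dimensions are placeholder assumptions (sized generously in case the quantizer packs weights group-wise):

    import torch.nn as nn

    from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

    # Placeholder model: any module containing nn.Linear layers would do.
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

    # Mirrors the new "vulkan_4w" branch in quantize(): linear weights are
    # replaced with 4-bit weight-only quantized equivalents and the
    # transformed model is returned.
    model = VkInt4WeightOnlyQuantizer().quantize(model)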
3 changes: 0 additions & 3 deletions extension/llm/export/partitioner_lib.py
@@ -37,9 +37,6 @@ def get_vulkan_partitioner(
     assert (
         dtype_override == "fp32" or dtype_override is None
     ), "Vulkan backend does not support non fp32 dtypes at the moment"
-    assert (
-        quantization_mode is None
-    ), "Vulkan backend does not support quantization at the moment"
     from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
         VulkanPartitioner,
     )
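With this assertion removed, a Vulkan partitioner can now be requested alongside a quantization mode. A hedged sketch, assuming the two parameter names visible in this hunk (dtype_override and quantization_mode) and that both accept Optional[str]:

    from executorch.extension.llm.export.partitioner_lib import get_vulkan_partitioner

    # Before this change, any non-None quantization_mode tripped the
    # removed assert; now it is accepted.
    partitioner = get_vulkan_partitioner(
        dtype_override="fp32", quantization_mode="vulkan_4w"
    )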