Add torchao conversion #14545
Changes from 14 commits
@@ -452,6 +452,16 @@ class MPSConfig:
     enabled: bool = False


+@dataclass
+class TorchAOKernelsConfig:
+    """
+    Configures the torchao-kernels backend.
+    """
+
+    convert_linear: bool = False
+    convert_tied_embedding: bool = False
+
+
 @dataclass
 class BackendConfig:
     """
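For readability, here is a minimal standalone sketch of the dataclass this hunk introduces, recreated outside the diff so it can be run on its own (the real definition lives in the config module being modified):

```python
from dataclasses import dataclass


@dataclass
class TorchAOKernelsConfig:
    """Configures the torchao-kernels backend."""

    convert_linear: bool = False
    convert_tied_embedding: bool = False


# Both conversions default to off; callers opt in per conversion.
cfg = TorchAOKernelsConfig(convert_linear=True)
assert cfg.convert_linear and not cfg.convert_tied_embedding
```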
@@ -464,6 +474,7 @@ class BackendConfig:
     vulkan: VulkanConfig = field(default_factory=VulkanConfig)
     qnn: QNNConfig = field(default_factory=QNNConfig)
     mps: MPSConfig = field(default_factory=MPSConfig)
+    torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)


 ################################################################################
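A hedged sketch of how the new field is reached once it hangs off BackendConfig; the surrounding classes are trimmed to two backends for illustration, and the real BackendConfig carries more fields:

```python
from dataclasses import dataclass, field


@dataclass
class TorchAOKernelsConfig:
    convert_linear: bool = False
    convert_tied_embedding: bool = False


@dataclass
class MPSConfig:
    enabled: bool = False


@dataclass
class BackendConfig:
    # Trimmed for illustration; the diff above only adds the `torchao` field.
    mps: MPSConfig = field(default_factory=MPSConfig)
    torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)


backend = BackendConfig()
backend.torchao.convert_linear = True  # opt in to the linear conversion
```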
@@ -632,6 +643,28 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         if hasattr(args, "mps"):
             llm_config.backend.mps.enabled = args.mps

+        # TorchAoKernels
+        if any(
+            hasattr(args, a)
+            for a in [
+                "torchao_kernels",
+                "torchao_kernels_linear",
+                "torchao_kernels_tied_embedding",
+            ]
+        ):
+            if hasattr(args, "torchao_kernels") and args.torchao_kernels:
+                # Enable all conversions if torchao_kernels is specified
+                llm_config.backend.torchao.convert_linear = True
+                llm_config.backend.torchao.convert_tied_embedding = True
+            else:
+                # Otherwise, only enable the conversions that are specified
+                llm_config.backend.torchao.convert_linear = getattr(
+                    args, "torchao_kernels_linear", False
+                )
+                llm_config.backend.torchao.convert_tied_embedding = getattr(
+                    args, "torchao_kernels_tied_embedding", False
+                )
+
         # DebugConfig
         if hasattr(args, "profile_memory"):
             llm_config.debug.profile_memory = args.profile_memory
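The hunk above only reads attributes off the parsed namespace; the CLI flags themselves are defined elsewhere in the PR. A hedged sketch of how such flags might be declared and how the branch behaves, where the flag names are assumptions inferred from the attribute names used above:

```python
import argparse

parser = argparse.ArgumentParser()
# Flag names are assumptions inferred from the attribute names in the hunk above.
parser.add_argument("--torchao_kernels", action="store_true")
parser.add_argument("--torchao_kernels_linear", action="store_true")
parser.add_argument("--torchao_kernels_tied_embedding", action="store_true")
args = parser.parse_args(["--torchao_kernels_linear"])

# Mirrors the branch in the diff: the umbrella flag turns everything on,
# otherwise each conversion is enabled only if its own flag was passed.
if args.torchao_kernels:
    convert_linear = convert_tied_embedding = True
else:
    convert_linear = getattr(args, "torchao_kernels_linear", False)
    convert_tied_embedding = getattr(args, "torchao_kernels_tied_embedding", False)

print(convert_linear, convert_tied_embedding)  # True False
```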
Reviewer comment: Can we follow the other backend config examples and use `enabled`?
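A sketch of what the reviewer appears to be suggesting, modeled on the other backend configs such as MPSConfig, which expose a single `enabled` flag; this is an assumption about the suggestion, not code from the PR:

```python
from dataclasses import dataclass


@dataclass
class TorchAOKernelsConfig:
    """Configures the torchao-kernels backend."""

    # Single on/off switch, matching the other backend configs (e.g. MPSConfig).
    enabled: bool = False
    # The finer-grained toggles could remain as optional knobs under that switch.
    convert_linear: bool = False
    convert_tied_embedding: bool = False
```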