
Commit 9a886d2

navsud authored and facebook-github-bot committed
Enable quantization for bf16 model (#14558)
Summary: To save GPU memory, the `bfloat16` dtype is commonly used when training LLMs. Currently, the quantizer skips quantization of nodes whose dtype is not float32. This change enables quantization of bf16 nodes as well.

Reviewed By: billmguo

Differential Revision: D82866443
1 parent: c98079a

File tree

1 file changed, +1 −1 lines


backends/qualcomm/quantizer/annotators.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def _is_float_tensor(node: Node):
         or not isinstance(node.meta["val"], FakeTensor)
     ):
         return False
-    return node.meta["val"].dtype == torch.float32
+    return node.meta["val"].dtype in (torch.bfloat16, torch.float32)


 def _mark_nodes_as_annotated(nodes: List[Node]):
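For context, a minimal sketch of how _is_float_tensor reads after this change. The two guard clauses come from the hunk above; the imports and the exact shape of the enclosing if-condition are assumptions rather than a verbatim copy of annotators.py:

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import Node

def _is_float_tensor(node: Node):
    # A node only qualifies for quantization if it carries a FakeTensor
    # in its "val" meta entry (guard reconstructed from the diff hunk).
    if (
        "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    # Previously only float32 tensors passed this check; bf16 tensors
    # are now treated as quantizable as well.
    return node.meta["val"].dtype in (torch.bfloat16, torch.float32)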
