Skip to content

Commit 074ab44

Browse files
authored
Merge pull request #1 from bhargaveede/HPU_Backend
Added HPU as new backend
2 parents cd73601 + 568d5ac commit 074ab44

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed

bitsandbytes/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,11 +28,23 @@
2828
"cuda", # includes ROCm
2929
"xpu", # Intel GPU
3030
"cpu",
31+
"hpu",
3132
}
3233

3334
# Always register the CPU backend.
3435
register_backend("cpu", CPUBackend())
3536

37+
# Register the HPU backend, if available.
# HPU support is optional: it needs the `habana_frameworks` package and an
# HPU device that torch reports as available.
try:
    # NOTE(review): presumably this import is what exposes `torch.hpu`;
    # confirm against the Habana PyTorch bridge documentation.
    import habana_frameworks.torch

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        from .backends.hpu import HPUBackend

        register_backend("hpu", HPUBackend())
except ImportError as err:
    # A missing optional dependency is the normal case on non-HPU machines,
    # so do not write to stdout at import time. Emit a debug-level log record
    # (with the underlying error) that applications can opt in to seeing.
    import logging

    logging.getLogger(__name__).debug("Unable to register HPU backend: %s", err)
47+
3648
# Register either CUDA or ROCm backend, if available.
3749
# Only one of these backends can be used at a time, since the torch.device semantics are
3850
# the same for both torch+rocm and torch+cuda (e.g. device name is "cuda")

bitsandbytes/backends/hpu.py

Lines changed: 147 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,147 @@
1+
from typing import Literal, Optional, Tuple, Union
2+
3+
import torch
4+
5+
from bitsandbytes.utils import QuantState
6+
7+
from .base import Backend
8+
from .cpu_xpu_common import (
9+
dequantize_4bit_impl,
10+
gemm_4bit_impl,
11+
igemmlt_impl,
12+
mm_dequant_impl,
13+
quantize_4bit_impl,
14+
)
15+
16+
Tensor = torch.Tensor
17+
18+
19+
class HPUBackend(Backend):
    """Backend implementation registered under the ``"hpu"`` device name.

    Apart from ``transform`` and ``extract_outliers``, which are implemented
    directly with tensor operations, every method delegates to the shared
    ``*_impl`` helpers imported from ``cpu_xpu_common``.
    """

    # Dtypes forwarded to ``mm_dequant_impl`` for the computation and result.
    mm_dequant_compute_dtype = torch.bfloat16
    mm_dequant_output_dtype = torch.bfloat16

    def transform(
        self,
        A: torch.Tensor,
        to_order: str,
        from_order="row",
        out: Optional[torch.Tensor] = None,
        transpose=False,
        state: Optional[Tuple[torch.Size, str]] = None,
        ld=None,
    ):
        """
        Transform tensor A to to_order. It was originally designed for CUDA.

        For HPU no layout change is performed: the result is ``A`` itself when
        ``transpose`` is false, and ``A.T`` otherwise. ``to_order``,
        ``from_order`` and ``ld`` are accepted for interface compatibility but
        are not used.

        Args:
            A: Tensor to transform.
            out: Optional destination. When given, the result is copied into
                it in place and ``out`` is returned; otherwise ``A`` (or the
                view ``A.T``) is returned without copying.
            transpose: Whether to return the transpose of ``A``.
            state: Passed through unchanged.

        Returns:
            A ``(tensor, state)`` tuple.
        """
        result = A.T if transpose else A
        if out is None:
            return result, state
        out.copy_(result)
        return out, state

    def igemmlt(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        SA: Tuple[torch.Size, str],
        SB: Tuple[torch.Size, str],
        out: Optional[torch.Tensor] = None,
        Sout: Optional[Tuple[torch.Size, str]] = None,
        dtype=torch.int32,
    ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
        """Delegate to ``igemmlt_impl`` with all arguments passed through unchanged."""
        return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

    def mm_dequant(
        self,
        A: torch.Tensor,
        quant_state: Tuple[torch.Size, str],
        row_stats: torch.Tensor,
        col_stats: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        new_row_stats: Optional[torch.Tensor] = None,
        new_col_stats: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to ``mm_dequant_impl``.

        All arguments are forwarded as given, followed by this backend's
        ``mm_dequant_compute_dtype`` and ``mm_dequant_output_dtype``
        (both ``torch.bfloat16``).
        """
        return mm_dequant_impl(
            A,
            quant_state,
            row_stats,
            col_stats,
            out,
            new_row_stats,
            new_col_stats,
            bias,
            self.mm_dequant_compute_dtype,
            self.mm_dequant_output_dtype,
        )

    def extract_outliers(
        self,
        A: torch.Tensor,
        SA: Tuple[torch.Size, str],
        idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Extract the columns of A selected by idx.

        ``SA`` is accepted for interface compatibility but is not used.
        The result is returned as a contiguous tensor.
        """
        return A[:, idx].contiguous()

    def quantize_4bit(
        self,
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: Optional[int] = 64,
        compress_statistics=False,
        quant_type: Literal["fp4", "nf4"] = "fp4",
        quant_storage=torch.uint8,
    ) -> Tuple[torch.Tensor, QuantState]:
        """Delegate 4-bit quantization of ``A`` to ``quantize_4bit_impl``.

        Args:
            blocksize: Block size forwarded to the implementation; ``None`` is
                treated as 64.
            quant_storage: Must be ``torch.uint8``; it is validated here and
                not forwarded.

        Raises:
            AssertionError: If ``quant_storage`` is not ``torch.uint8``.
        """
        if blocksize is None:
            blocksize = 64
        if quant_storage != torch.uint8:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. AssertionError is kept so callers
            # see the same exception type as before.
            raise AssertionError(
                f"HPU backend only supports uint8 quant_storage, got {quant_storage}"
            )
        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

    def dequantize_4bit(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: Optional[int] = 64,
        quant_type: Literal["fp4", "nf4"] = "fp4",
    ) -> torch.Tensor:
        """Delegate 4-bit dequantization of ``A`` to ``dequantize_4bit_impl``.

        ``blocksize=None`` is treated as 64; all other arguments are
        forwarded unchanged.
        """
        if blocksize is None:
            blocksize = 64
        return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

    def gemv_4bit(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        transposed_A=False,
        transposed_B=False,
        state: Optional[QuantState] = None,
    ) -> torch.Tensor:
        """Delegate to ``gemm_4bit_impl`` after checking that ``state`` is set.

        Raises:
            ValueError: If ``state`` is ``None``.
        """
        if state is None:
            raise ValueError(
                "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
            )

        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

0 commit comments

Comments
 (0)