Commit 24150d0
TST Enable BNB tests on XPU (#2396)
1 parent 461f642 commit 24150d0

8 files changed: 97 additions, 105 deletions

src/peft/utils/integrations.py

Lines changed: 7 additions & 2 deletions

@@ -23,6 +23,8 @@
 import transformers
 from torch import nn
 
+from peft.import_utils import is_xpu_available
+
 
 @contextmanager
 def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
@@ -90,8 +92,11 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     # BNB requires CUDA weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
-    if is_cpu and torch.cuda.is_available():
-        weight = weight.to(torch.device("cuda"))
+    if is_cpu:
+        if torch.cuda.is_available():
+            weight = weight.to(torch.device("cuda"))
+        elif is_xpu_available():
+            weight = weight.to(torch.device("xpu"))
 
     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
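
The change to dequantize_bnb_weight moves CPU-resident weights onto whichever accelerator is present before bitsandbytes dequantizes them ("BNB requires CUDA weights", per the existing comment). A minimal sketch of that fallback pattern in isolation; the helper name _move_to_accelerator is illustrative and not part of the commit:

import torch

from peft.import_utils import is_xpu_available


def _move_to_accelerator(weight: torch.Tensor) -> torch.Tensor:
    # bitsandbytes kernels expect the weight on an accelerator, not on CPU.
    if weight.device.type == "cpu":
        if torch.cuda.is_available():
            weight = weight.to(torch.device("cuda"))
        elif is_xpu_available():
            weight = weight.to(torch.device("xpu"))
    return weight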

src/peft/utils/loftq_utils.py

Lines changed: 6 additions & 6 deletions

@@ -22,13 +22,14 @@
 from typing import Callable, Optional, Union
 
 import torch
+from accelerate.utils.memory import clear_device_cache
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import HFValidationError, LocalEntryNotFoundError
 from safetensors import SafetensorError, safe_open
 from transformers.utils import cached_file
 from transformers.utils.hub import get_checkpoint_shard_files
 
-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_xpu_available
 
 
 class NFQuantizer:
@@ -201,20 +202,19 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
     out_feature, in_feature = weight.size()
     device = weight.device
     dtype = weight.dtype
-
     logging.info(
         f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} | Num Iter: {num_iter} | Num Bits: {num_bits}"
     )
     if not is_bnb_4bit_available() or num_bits in [2, 8]:
         quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
         compute_device = device
     else:
-        compute_device = "cuda"
+        compute_device = "xpu" if is_xpu_available() else "cuda"
 
     weight = weight.to(device=compute_device, dtype=torch.float32)
     res = weight.clone()
     for i in range(num_iter):
-        torch.cuda.empty_cache()
+        clear_device_cache()
         # Quantization
         if num_bits == 4 and is_bnb_4bit_available():
             qweight = bnb.nn.Params4bit(
@@ -246,12 +246,12 @@ def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int):
     if not is_bnb_4bit_available():
         raise ValueError("bitsandbytes 4bit quantization is not available.")
 
-    compute_device = "cuda"
+    compute_device = "xpu" if is_xpu_available() else "cuda"
     dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)
 
     weight = weight.to(device=compute_device, dtype=torch.float32)
     residual = weight - dequantized_weight
-    torch.cuda.empty_cache()
+    clear_device_cache()
     # Decompose the residual by SVD
     output = _low_rank_decomposition(residual, reduced_rank=reduced_rank)
     L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
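
Two device-agnostic substitutions carry this file: the LoftQ compute device falls back to "xpu" when an XPU is available, and torch.cuda.empty_cache() is replaced by Accelerate's clear_device_cache(), which frees cached memory on whichever backend is active. A small sketch of the pattern, assuming torch, accelerate, and peft are installed; pick_compute_device is an illustrative name, not from the commit:

import torch

from accelerate.utils.memory import clear_device_cache
from peft.import_utils import is_xpu_available


def pick_compute_device() -> str:
    # Mirror the selection in loftq_init / _loftq_init_new above:
    # prefer XPU when present, otherwise keep the old CUDA default.
    return "xpu" if is_xpu_available() else "cuda"


# Between heavy iterations, free cached blocks on the active accelerator
# instead of hard-coding torch.cuda.empty_cache():
clear_device_cache()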

tests/bnb/test_bnb_regression.py

Lines changed: 12 additions & 10 deletions

@@ -27,10 +27,12 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
 
+from peft.import_utils import is_xpu_available
+
 
 bnb = pytest.importorskip("bitsandbytes")
 
-device = torch.device("cuda")
+device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")
 
 
 def bytes_from_tensor(x):
@@ -47,7 +49,7 @@ def bytes_from_tensor(x):
 ############
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -70,7 +72,7 @@ def test_opt_350m_4bit():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -89,7 +91,7 @@ def test_opt_350m_8bit():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_double_quant():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -112,7 +114,7 @@ def test_opt_350m_4bit_double_quant():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_compute_dtype_float16():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -135,7 +137,7 @@ def test_opt_350m_4bit_compute_dtype_float16():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_type_nf4():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -159,7 +161,7 @@ def test_opt_350m_4bit_quant_type_nf4():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_storage():
     # note: using torch.float32 instead of the default torch.uint8 does not seem to affect the result
     torch.manual_seed(0)
@@ -184,7 +186,7 @@ def test_opt_350m_4bit_quant_storage():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit_threshold():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -211,7 +213,7 @@ def test_opt_350m_8bit_threshold():
 ###########
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_flan_t5_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -235,7 +237,7 @@ def test_flan_t5_4bit():
     torch.testing.assert_allclose(output, expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 @pytest.mark.xfail  # might not be reproducible depending on hardware
 def test_flan_t5_8bit():
     torch.manual_seed(0)
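
The regression tests now share a single module-level device and gate every test on either backend being present. A minimal sketch of that gating pattern, assuming a pytest module; the reusable marker and the test below are illustrative, while the commit keeps the skipif condition inline on each test:

import pytest
import torch

from peft.import_utils import is_xpu_available


# One source of truth for the accelerator used by every test in the module.
device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")

# Factoring the condition into a reusable marker avoids repeating it per test.
requires_accelerator = pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()),
    reason="No CUDA or XPU device available.",
)


@requires_accelerator
def test_device_accepts_tensors():
    # Placeholder body: the selected device can hold a tensor.
    x = torch.ones(2, 2, device=device)
    assert x.device.type in ("cuda", "xpu")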

tests/test_common_gpu.py

Lines changed: 14 additions & 15 deletions

@@ -63,7 +63,6 @@
     require_multi_accelerator,
     require_non_cpu,
     require_torch_gpu,
-    require_torch_multi_gpu,
 )
 
 
@@ -594,7 +593,7 @@ def test_lora_causal_lm_multi_gpu_inference(self):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)
 
-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_lora_seq2seq_lm_multi_gpu_inference(self):
@@ -622,7 +621,7 @@ def test_lora_seq2seq_lm_multi_gpu_inference(self):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)
 
-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_8bit(self):
@@ -645,7 +644,7 @@ def test_adaption_prompt_8bit(self):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)
 
-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_4bit(self):
@@ -668,7 +667,7 @@ def test_adaption_prompt_4bit(self):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_print_4bit_expected(self):
@@ -778,7 +777,7 @@ def test_8bit_merge_lora(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_and_disable_lora(self):
@@ -814,7 +813,7 @@ def test_8bit_merge_and_disable_lora(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_lora_with_bias(self):
@@ -846,7 +845,7 @@ def test_8bit_merge_lora_with_bias(self):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora(self):
@@ -888,7 +887,7 @@ def test_4bit_merge_lora(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_and_disable_lora(self):
@@ -930,7 +929,7 @@ def test_4bit_merge_and_disable_lora(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora_with_bias(self):
@@ -971,7 +970,7 @@ def test_4bit_merge_lora_with_bias(self):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_lora_mixed_adapter_batches_lora(self):
@@ -1042,7 +1041,7 @@ def test_4bit_lora_mixed_adapter_batches_lora(self):
         assert torch.allclose(out_adapter0[1::3], out_mixed[1::3], atol=atol, rtol=rtol)
         assert torch.allclose(out_adapter1[2::3], out_mixed[2::3], atol=atol, rtol=rtol)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_lora_mixed_adapter_batches_lora(self):
@@ -1124,7 +1123,7 @@ def test_serialization_shared_tensors(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir, safe_serialization=True)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_inference(self):
@@ -1163,7 +1162,7 @@ def test_4bit_dora_inference(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_dora_inference(self):
@@ -1197,7 +1196,7 @@ def test_8bit_dora_inference(self):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)
 
-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_merging(self):
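
Here the CUDA-specific decorators give way to the accelerator-neutral require_non_cpu and require_multi_accelerator, which this file already imports from the test utilities. The following is only a hedged sketch of how such markers can be built with pytest; it is not the actual definition in peft's testing utilities, and it assumes only CUDA and XPU backends:

import pytest
import torch

from peft.import_utils import is_xpu_available


def _accelerator_count() -> int:
    # Assumption for this sketch: only CUDA and XPU backends are considered.
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if is_xpu_available():
        return torch.xpu.device_count()
    return 0


# Skip unless at least one non-CPU device exists (GPU or XPU).
require_non_cpu_sketch = pytest.mark.skipif(
    _accelerator_count() < 1, reason="test requires an accelerator"
)

# Skip unless at least two accelerators exist, for multi-device tests.
require_multi_accelerator_sketch = pytest.mark.skipif(
    _accelerator_count() < 2, reason="test requires multiple accelerators"
)

Applied as @require_non_cpu_sketch on a test, a marker like this skips cleanly on CPU-only runners while letting the same test run unchanged on both CUDA and XPU machines, which is the point of the decorator swap in this file.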
