|
29 | 29 |     nightly,
30 | 30 |     numpy_cosine_similarity_distance,
31 | 31 |     require_accelerate,
| 32 | +    require_accelerator,
32 | 33 |     require_big_accelerator,
33 | 34 |     require_gguf_version_greater_or_equal,
34 | 35 |     require_peft_backend,

37 | 38 |
38 | 39 |
39 | 40 | if is_gguf_available():
| 41 | +    import gguf
| 42 | +
40 | 43 |     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
41 | 44 |
42 | 45 | enable_full_determinism()
43 | 46 |
44 | 47 |
| 48 | +@nightly
| 49 | +@require_accelerate
| 50 | +@require_accelerator
| 51 | +@require_gguf_version_greater_or_equal("0.10.0")
| 52 | +class GGUFCudaKernelsTests(unittest.TestCase):
| 53 | +    def setUp(self):
| 54 | +        gc.collect()
| 55 | +        backend_empty_cache(torch_device)
| 56 | +
| 57 | +    def tearDown(self):
| 58 | +        gc.collect()
| 59 | +        backend_empty_cache(torch_device)
| 60 | +
| 61 | +    def test_cuda_kernels_vs_native(self):
| 62 | +        if torch_device != "cuda":
| 63 | +            self.skipTest("CUDA kernels test requires CUDA device")
| 64 | +
| 65 | +        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
| 66 | +
| 67 | +        if not can_use_cuda_kernels:
| 68 | +            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
| 69 | +
| 70 | +        test_quant_types = ["Q4_0", "Q4_K"]
| 71 | +        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
| 72 | +        compute_dtype = torch.bfloat16
| 73 | +
| 74 | +        for quant_type in test_quant_types:
| 75 | +            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
| 76 | +            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
| 77 | +
| 78 | +            in_features, out_features = 512, 512
| 79 | +            total_elements = in_features * out_features
| 80 | +            n_blocks = total_elements // block_size
| 81 | +            weight_bytes = n_blocks * type_size
| 82 | +
| 83 | +            torch.manual_seed(42)
| 84 | +            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
| 85 | +            weight = GGUFParameter(weight_data, quant_type=qtype)
| 86 | +
| 87 | +            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
| 88 | +
| 89 | +            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
| 90 | +            linear.weight = weight
| 91 | +            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
| 92 | +            linear = linear.to(torch_device)
| 93 | +
| 94 | +            with torch.no_grad():
| 95 | +                output_native = linear.forward_native(x)
| 96 | +                output_cuda = linear.forward_cuda(x)
| 97 | +
| 98 | +            # Compare outputs
| 99 | +            max_diff = torch.abs(output_cuda - output_native).max()
| 100 | +            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
| 101 | +
| 102 | +
45 | 103 | @nightly
46 | 104 | @require_big_accelerator
47 | 105 | @require_accelerate
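The one step in the new test whose numbers are not self-explanatory is how it sizes the random uint8 buffer that stands in for a quantized 512x512 weight. A minimal sketch of that arithmetic, using the same gguf.GGML_QUANT_SIZES table the test reads (the per-block sizes quoted in the comments are taken from current gguf releases and are an assumption here, not part of the diff):

import gguf

in_features, out_features = 512, 512
total_elements = in_features * out_features  # 262144 weights in the test layer

for name in ("Q4_0", "Q4_K"):
    qtype = getattr(gguf.GGMLQuantizationType, name)
    # GGML_QUANT_SIZES maps each quantization type to (elements per block, bytes per block).
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    n_blocks = total_elements // block_size  # weights are packed whole blocks at a time
    weight_bytes = n_blocks * type_size      # size of the raw uint8 buffer GGUFParameter wraps
    print(f"{name}: {n_blocks} blocks of {type_size} bytes -> {weight_bytes} bytes")
    # e.g. Q4_0 packs 32 weights into 18 bytes and Q4_K packs 256 weights into 144 bytes,
    # i.e. 4.5 bits per weight either way (147456 bytes for this 512x512 layer).

The buffer is filled with arbitrary random bytes rather than a real quantization of anything, which is enough for this test: it only asserts that forward_cuda and forward_native decode the same bytes to the same output.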
|