@@ -73,15 +73,12 @@ def test_cuda_kernels_vs_native(self):
 
         for quant_type in test_quant_types:
             qtype = getattr(gguf.GGMLQuantizationType, quant_type)
-            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
-
             in_features, out_features = 512, 512
-            total_elements = in_features * out_features
-            n_blocks = total_elements // block_size
-            weight_bytes = n_blocks * type_size
 
             torch.manual_seed(42)
-            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
             weight = GGUFParameter(weight_data, quant_type=qtype)
 
             x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
@@ -95,9 +92,9 @@ def test_cuda_kernels_vs_native(self):
             output_native = linear.forward_native(x)
             output_cuda = linear.forward_cuda(x)
 
-            # Compare outputs
-            max_diff = torch.abs(output_cuda - output_native).max()
-            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+            assert torch.allclose(output_native, output_cuda, rtol=1e-2), (
+                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+            )
 
 
 @nightly
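
Context for the change: the old test filled the weight buffer with random bytes, which are not valid GGUF blocks (the per-block scales are garbage), so the native path and the CUDA kernel could legitimately disagree. The new version packs a real float matrix through gguf-py first. A minimal stand-alone sketch of that round-trip, using an illustrative quant type and shape that are not taken from the PR:

```python
import gguf
import numpy as np

# Illustrative choices, not from the PR: Q4_0 blocks and a 512x512 matrix.
qtype = gguf.GGMLQuantizationType.Q4_0
weights = np.random.randn(512, 512).astype(np.float32)

# Pack the floats into GGUF blocks (per-block scales + quantized values)...
packed = gguf.quants.quantize(weights, qtype)  # uint8 array in GGUF block layout

# ...and unpack them again. The round trip is lossy, which is why the updated
# assertion compares the two forward paths with a loose relative tolerance
# (torch.allclose, rtol=1e-2) instead of the old 1e-4 absolute bound.
restored = gguf.quants.dequantize(packed, qtype)
print(np.abs(restored - weights).max())
```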