@@ -73,15 +73,12 @@ def test_cuda_kernels_vs_native(self):
 
         for quant_type in test_quant_types:
             qtype = getattr(gguf.GGMLQuantizationType, quant_type)
-            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
-
             in_features, out_features = 512, 512
-            total_elements = in_features * out_features
-            n_blocks = total_elements // block_size
-            weight_bytes = n_blocks * type_size
 
             torch.manual_seed(42)
-            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
             weight = GGUFParameter(weight_data, quant_type=qtype)
 
             x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
@@ -95,9 +92,9 @@ def test_cuda_kernels_vs_native(self):
                 output_native = linear.forward_native(x)
                 output_cuda = linear.forward_cuda(x)
 
-            # Compare outputs
-            max_diff = torch.abs(output_cuda - output_native).max()
-            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+            assert torch.allclose(output_native, output_cuda, rtol=1e-2), (
+                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+            )
 
 
 @nightly
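
The move away from torch.randint is the core of the fix: random uint8 bytes rarely form valid GGML blocks (each block embeds an fp16 scale that can decode to inf or NaN), so the native and CUDA dequantization paths could legitimately disagree on garbage input. Round-tripping a real float tensor through gguf.quants.quantize guarantees well-formed blocks. A minimal sketch of that round trip, assuming a recent gguf package from llama.cpp; quantize, GGMLQuantizationType, and GGML_QUANT_SIZES appear in the test above, while dequantize and the Q4_0 choice are illustrative additions:

import numpy as np
import gguf

qtype = gguf.GGMLQuantizationType.Q4_0
float_weight = np.random.randn(512, 512).astype(np.float32)

# quantize() packs the floats into uint8 blocks (for Q4_0: one fp16
# scale plus 32 4-bit values per 18-byte block); dequantize() reverses it
packed = gguf.quants.quantize(float_weight, qtype)
restored = gguf.quants.dequantize(packed, qtype)

# Block bookkeeping the deleted test lines used to do by hand
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
assert packed.nbytes == float_weight.size // block_size * type_size

print(np.abs(restored - float_weight).max())  # small but nonzero quantization error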
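The tolerance change follows from the same reasoning: a fixed absolute bound of 1e-4 is too strict once the weights carry real quantization error, which grows with activation magnitude, whereas torch.allclose with rtol=1e-2 scales the allowed difference per element. A quick illustration of the semantics (torch.allclose checks |a - b| <= atol + rtol * |b| elementwise, with atol defaulting to 1e-08); the tensors here are made-up values, not test outputs:

import torch

output_native = torch.tensor([1.000, 100.0])
output_cuda = torch.tensor([1.005, 100.5])

# Both elements differ by 0.5% of their magnitude, so a 1% relative
# tolerance passes while a 0.1% one fails; a fixed absolute bound
# would reject the large element no matter how accurate it is.
assert torch.allclose(output_native, output_cuda, rtol=1e-2)
assert not torch.allclose(output_native, output_cuda, rtol=1e-3)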