@@ -314,81 +314,83 @@ void spmm_coo_very_sparse_naive_int8(
314314#if BUILD_XPU
315315
316316void dequantizeBlockwise_fp16 (
317- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
317+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
318318) {
319319 dequantizeBlockwise<sycl::half, General8bit>(code, A, absmax, out, blocksize, n, stream);
320320}
321321
322322void dequantizeBlockwise_fp16_fp4 (
323- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
323+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
324324) {
325325 dequantizeBlockwise<sycl::half, FP4>(NULL , A, absmax, out, blocksize, n, stream);
326326}
327327
328328void dequantizeBlockwise_fp16_nf4 (
329- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
329+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
330330) {
331331 dequantizeBlockwise<sycl::half, NF4>(NULL , A, absmax, out, blocksize, n, stream);
332332}
333333
334334void dequantizeBlockwise_fp32 (
335- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
335+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
336336) {
337337 dequantizeBlockwise<float , General8bit>(code, A, absmax, out, blocksize, n, stream);
338338}
339339
340340void dequantizeBlockwise_fp32_fp4 (
341- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
341+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
342342) {
343343 dequantizeBlockwise<float , FP4>(NULL , A, absmax, out, blocksize, n, stream);
344344}
345345
346346void dequantizeBlockwise_fp32_nf4 (
347- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
347+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
348348) {
349349 dequantizeBlockwise<float , NF4>(NULL , A, absmax, out, blocksize, n, stream);
350350}
351351
352352void dequantizeBlockwise_bf16 (
353- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
353+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
354354 sycl::queue* stream
355355) {
356356 dequantizeBlockwise<sycl::ext::oneapi::bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
357357}
358358
359359void dequantizeBlockwise_bf16_fp4 (
360- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
360+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
361361 sycl::queue* stream
362362) {
363363 dequantizeBlockwise<sycl::ext::oneapi::bfloat16, FP4>(NULL , A, absmax, out, blocksize, n, stream);
364364}
365365
366366void dequantizeBlockwise_bf16_nf4 (
367- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
367+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
368368 sycl::queue* stream
369- ) {
369+ ) {
370370 dequantizeBlockwise<sycl::ext::oneapi::bfloat16, NF4>(NULL , A, absmax, out, blocksize, n, stream);
371371}
372372
373373void gemv_4bit_inference_fp16 (
374- int m, int n, int k, sycl::half * A, unsigned char * B, float * absmax, float * datatype, sycl::half * out,
375- int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
374+ int m, int n, int k, sycl::half* A, unsigned char * B, float * absmax, float * datatype, sycl::half* out, int lda ,
375+ int ldb, int ldc, int blocksize, sycl::queue* stream
376376) {
377- gemv_4bit_inference<sycl::half, 16 >(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
377+ gemv_4bit_inference<sycl::half, 16 >(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
378378}
379379
380380void gemv_4bit_inference_bf16 (
381- int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char * B, float * absmax, float * datatype,
382- sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
381+ int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char * B, float * absmax, float * datatype,
382+ sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
383383) {
384- gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16 >(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
384+ gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16 >(
385+ m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
386+ );
385387}
386388
387389void gemv_4bit_inference_fp32 (
388- int m, int n, int k, float * A, unsigned char * B, float * absmax, float * datatype, float * out, int lda,
389- int ldb, int ldc, int blocksize, sycl::queue* stream
390+ int m, int n, int k, float * A, unsigned char * B, float * absmax, float * datatype, float * out, int lda, int ldb ,
391+ int ldc, int blocksize, sycl::queue* stream
390392) {
391- gemv_4bit_inference<float , 32 >(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
393+ gemv_4bit_inference<float , 32 >(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
392394}
393395
394396#endif
@@ -746,81 +748,81 @@ void cgemm_4bit_inference_naive_fp32(
746748#if BUILD_XPU
747749
748750void cdequantize_blockwise_fp16_fp4 (
749- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
751+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
750752) {
751753 dequantizeBlockwise_fp16_fp4 (code, A, absmax, out, blocksize, n, stream);
752754}
753755
754756void cdequantize_blockwise_fp16 (
755- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
757+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
756758) {
757759 dequantizeBlockwise_fp16 (code, A, absmax, out, blocksize, n, stream);
758760}
759761
760762void cdequantize_blockwise_fp16_nf4 (
761- float * code, unsigned char * A, float * absmax, sycl::half * out, int blocksize, const int n, sycl::queue* stream
763+ float * code, unsigned char * A, float * absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
762764) {
763765 dequantizeBlockwise_fp16_nf4 (code, A, absmax, out, blocksize, n, stream);
764766}
765767
766768void cdequantize_blockwise_fp32 (
767- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
769+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
768770) {
769771 dequantizeBlockwise_fp32 (code, A, absmax, out, blocksize, n, stream);
770772}
771773
772774void cdequantize_blockwise_fp32_fp4 (
773- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
775+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
774776) {
775777 dequantizeBlockwise_fp32_fp4 (code, A, absmax, out, blocksize, n, stream);
776778}
777779
778780void cdequantize_blockwise_fp32_nf4 (
779- float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
781+ float * code, unsigned char * A, float * absmax, float * out, int blocksize, const int n, sycl::queue* stream
780782) {
781783 dequantizeBlockwise_fp32_nf4 (code, A, absmax, out, blocksize, n, stream);
782784}
783785
784786void cdequantize_blockwise_bf16 (
785- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
787+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
786788 sycl::queue* stream
787789) {
788790 dequantizeBlockwise_bf16 (code, A, absmax, out, blocksize, n, stream);
789791}
790792
791793void cdequantize_blockwise_bf16_fp4 (
792- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
794+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
793795 sycl::queue* stream
794796) {
795797 dequantizeBlockwise_bf16_fp4 (code, A, absmax, out, blocksize, n, stream);
796798}
797799
798800void cdequantize_blockwise_bf16_nf4 (
799- float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 * out, int blocksize, const int n,
801+ float * code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
800802 sycl::queue* stream
801803) {
802804 dequantizeBlockwise_bf16_nf4 (code, A, absmax, out, blocksize, n, stream);
803805}
804806
805807void cgemv_4bit_inference_fp16 (
806- int m, int n, int k, sycl::half * A, unsigned char * B, float * absmax, float * datatype, sycl::half * out,
807- int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
808+ int m, int n, int k, sycl::half* A, unsigned char * B, float * absmax, float * datatype, sycl::half* out, int lda ,
809+ int ldb, int ldc, int blocksize, sycl::queue* stream
808810) {
809- gemv_4bit_inference_fp16 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
811+ gemv_4bit_inference_fp16 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
810812}
811813
812814void cgemv_4bit_inference_bf16 (
813- int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char * B, float * absmax, float * datatype,
814- sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
815+ int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char * B, float * absmax, float * datatype,
816+ sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
815817) {
816- gemv_4bit_inference_bf16 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
818+ gemv_4bit_inference_bf16 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
817819}
818820
819821void cgemv_4bit_inference_fp32 (
820- int m, int n, int k, float * A, unsigned char * B, float * absmax, float * datatype, float * out, int lda,
821- int ldb, int ldc, int blocksize, sycl::queue* stream
822+ int m, int n, int k, float * A, unsigned char * B, float * absmax, float * datatype, float * out, int lda, int ldb ,
823+ int ldc, int blocksize, sycl::queue* stream
822824) {
823- gemv_4bit_inference_fp32 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
825+ gemv_4bit_inference_fp32 (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
824826}
825827
826828#endif
0 commit comments