File tree Expand file tree Collapse file tree 2 files changed +10
-10
lines changed
bitsandbytes/backends/npu Expand file tree Collapse file tree 2 files changed +10
-10
lines changed Original file line number Diff line number Diff line change @@ -308,11 +308,7 @@ endif()
308308
309309set_source_files_properties (${CPP_FILES} PROPERTIES LANGUAGE CXX)
310310add_library (bitsandbytes SHARED ${SRC_FILES} )
311- if (BUILD_NPU)
312- target_compile_features (bitsandbytes PUBLIC cxx_std_17)
313- else ()
314- target_compile_features (bitsandbytes PUBLIC cxx_std_14)
315- endif ()
311+ target_compile_features (bitsandbytes PUBLIC cxx_std_17)
316312target_include_directories (bitsandbytes PUBLIC csrc include )
317313
318314
Original file line number Diff line number Diff line change @@ -99,21 +99,25 @@ def _dequantize_4bit_impl(
9999 dtype in [torch .bfloat16 , torch .float16 , torch .float32 ],
100100 lambda : f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got { dtype } " ,
101101 )
102+ if out .dtype == torch .bfloat16 :
103+ # bf16: bf16 -> fp32 -> op -> fp32 -> bf16
104+ absmax = absmax .to (torch .float32 )
105+ out_fp32 = torch .empty (out .shape , dtype = torch .float32 , device = out .device )
106+ else :
107+ out_fp32 = out
108+
102109 args = (
103110 get_ptr (A ),
104111 get_ptr (absmax ),
105- get_ptr (out ),
112+ get_ptr (out_fp32 ),
106113 ct .c_int (blocksize ),
107114 ct .c_int (out .numel ()),
108115 torch .npu .current_stream (),
109116 )
110117
111118 if out .dtype == torch .bfloat16 :
112- # bf16: bf16 -> fp32 -> op -> fp32 -> bf16
113- absmax = absmax .to (torch .float32 )
114- out = out .to (torch .float32 )
115119 lib .cdequantize_blockwise_fp32_nf4 (* args )
116- out = out . to (torch .bfloat16 )
120+ out . copy_ ( out_fp32 . to (torch .bfloat16 ) )
117121 elif out .dtype == torch .float16 :
118122 lib .cdequantize_blockwise_fp16_nf4 (* args )
119123 elif out .dtype == torch .float32 :
You can’t perform that action at this time.
0 commit comments